This repository was archived by the owner on Aug 7, 2025. It is now read-only.

Commit 6dd7b66

Author: jagadeesh

replace iterator with run method

Signed-off-by: jagadeesh <jagadeeshj@ideas2it.com>

1 parent: 18b0f59

File tree

1 file changed: +6 -18 lines changed

examples/nvidia_dali/custom_handler.py

Lines changed: 6 additions & 18 deletions
@@ -9,7 +9,6 @@
 import numpy as np
 import torch
 from nvidia.dali.pipeline import Pipeline
-from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy
 
 from ts.torch_handler.image_classifier import ImageClassifier
 
@@ -32,7 +31,6 @@ def initialize(self, context):
         ]
         if not len(self.dali_file):
             raise RuntimeError("Missing dali pipeline file.")
-        self.PREFETCH_QUEUE_DEPTH = 2
         dali_config_file = os.path.join(self.model_dir, "dali_config.json")
         if not os.path.isfile(dali_config_file):
             raise RuntimeError("Missing dali_config.json file.")
@@ -43,6 +41,7 @@ def initialize(self, context):
             filename=dali_filename,
             batch_size=self.dali_configs["batch_size"],
             num_threads=self.dali_configs["num_threads"],
+            prefetch_queue_depth=1,
             device_id=self.dali_configs["device_id"],
             seed=self.dali_configs["seed"],
         )
@@ -62,27 +61,16 @@ def preprocess(self, data):
             list : The preprocess function returns the input image as a list of float tensors.
         """
         batch_tensor = []
+        result = []
 
         input_byte_arrays = [i["body"] if "body" in i else i["data"] for i in data]
         for byte_array in input_byte_arrays:
             np_image = np.frombuffer(byte_array, dtype=np.uint8)
             batch_tensor.append(np_image)  # we can use numpy
 
-        for _ in range(self.PREFETCH_QUEUE_DEPTH + 1):
-            self.pipe.feed_input("source", batch_tensor)
-
-        datum = DALIGenericIterator(
-            [self.pipe],
-            ["data"],
-            last_batch_policy=LastBatchPolicy.PARTIAL,
-            last_batch_padded=True,
-        )
-
-        result = []
-        for count, data in enumerate(datum):
-            self.pipe.feed_input("source", batch_tensor)
-            result.append(data[0]["data"])
-            if count == len(input_byte_arrays) - 1:
-                break
+        response = self.pipe.run(source=batch_tensor)
+        for idx, _ in enumerate(response[0]):
+            data = torch.tensor(response[0].at(idx))
+            result.append(data.unsqueeze(0))
 
         return torch.cat(result).to(self.device)
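
Note on the new preprocess path: pipe.run(source=batch_tensor) feeds the raw image buffers into an external-source input named "source" and returns the pipeline outputs directly, which is why the DALIGenericIterator and the feed_input prefetch loop could be dropped. For illustration only, a minimal sketch of a DALI pipeline this handler could work against is shown below. The operator choices, image size, normalization values, and output file name are assumptions made for the sketch, not taken from this repository; only the external source name "source" and the keyword-style run() call are dictated by the handler code above.

    # minimal_dali_pipeline.py -- illustrative sketch, not the example's actual pipeline
    from nvidia.dali import fn, pipeline_def, types

    @pipeline_def(batch_size=8, num_threads=2, device_id=0, seed=42)
    def image_pipeline():
        # The name "source" must match the keyword used in pipe.run(source=...)
        raw = fn.external_source(name="source", dtype=types.UINT8)
        # CPU decode so that response[0] is a TensorListCPU and .at(idx) yields a numpy array
        images = fn.decoders.image(raw, device="cpu")
        images = fn.resize(images, resize_x=224, resize_y=224)
        return fn.crop_mirror_normalize(
            images,
            dtype=types.FLOAT,
            output_layout="CHW",
            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
            std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
        )

    pipe = image_pipeline()
    # The handler later restores this file via Pipeline.deserialize(..., prefetch_queue_depth=1)
    pipe.serialize(filename="model.dali")

Because the deserialized pipeline uses prefetch_queue_depth=1, a single run() call consumes the single batch fed through the "source" keyword, so no extra prefetch feeding is required.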

Comments (0)