quantize: cannot get QLinearConcat to work??!!!
Your Name committed Apr 17, 2022
1 parent 8a50868 commit 1b238de
Showing 6 changed files with 31 additions and 24 deletions.
Binary file added .DS_Store
Binary file not shown.
1 change: 1 addition & 0 deletions deploy/ort_infer.py
@@ -212,6 +212,7 @@ def load_test_image(f, h, w):
 
     if "sparse" in args.model:
         masks = output[0][0]
+        masks = np.squeeze(masks, axis=1)
         scores = output[1][0]
         labels = output[2][0]
         keep = scores > 0.3
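The added squeeze implies the exported "masks" output now carries a singleton axis per instance. A minimal sketch of the shape handling; the (10, 1, 544, 960) shape is an illustrative assumption, not taken from the model:

```python
import numpy as np

# Assumed shape: one soft mask per instance, with a singleton channel axis.
masks = np.zeros((10, 1, 544, 960), dtype=np.float32)
masks = np.squeeze(masks, axis=1)  # -> (10, 544, 960): one 2-D mask per instance
```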
2 changes: 1 addition & 1 deletion deploy/quant_onnx/qt_atom_sparseinst.py
@@ -32,7 +32,7 @@ def pqt(onnx_f):
     input_shape = session.get_inputs()[0].shape
 
     calib_dataloader = get_calib_dataloader_coco(
-        coco_root, anno_f, preprocess_func=preprocess_func, input_names=input_name, bs=1
+        coco_root, anno_f, preprocess_func=preprocess_func, input_names=input_name, bs=1, max_step=50
     )
     quantize_static_onnx(onnx_f, calib_dataloader=calib_dataloader)
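quantize_static_onnx and get_calib_dataloader_coco are project helpers not shown in this commit. A minimal sketch of what the static-quantization step plausibly wraps, using ONNX Runtime's public API; the reader class, file paths, and input name below are assumptions. QOperator format is the mode that introduces QLinear* nodes such as the QLinearConcat named in the commit message:

```python
import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, quantize_static

class BoundedCalibReader(CalibrationDataReader):
    """Hypothetical reader: feeds at most max_step preprocessed batches."""

    def __init__(self, batches, input_name, max_step=50):
        self._it = iter([{input_name: b} for b, _ in zip(batches, range(max_step))])

    def get_next(self):
        return next(self._it, None)  # returning None ends calibration

batches = (np.random.rand(1, 3, 544, 960).astype(np.float32) for _ in range(50))
quantize_static(
    "sparse_inst_sim.onnx",                         # example input path
    "sparse_inst_int8.onnx",                        # example output path
    calibration_data_reader=BoundedCalibReader(batches, "images"),
    quant_format=QuantFormat.QOperator,             # emits QLinear* operators
)
```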
8 changes: 5 additions & 3 deletions export_onnx.py
@@ -239,8 +239,8 @@ def vis_res_fast(res, img, colors):
 
 def get_model_infos(config_file):
     if "sparse_inst" in config_file:
-        output_names = ["masks", "scores", "labels"]
-        output_names = ["masks", "labels"]
+        # output_names = ["masks", "scores", "labels"]
+        output_names = ["masks", "scores"]
         input_names = ["images"]
         dynamic_axes = {"images": {0: "batch"}}
         return input_names, output_names, dynamic_axes
@@ -300,7 +300,9 @@ def get_model_infos(config_file):
 
     # use onnxsimplify to reduce redundant model.
     sim_onnx = onnx_f.replace(".onnx", "_sim.onnx")
-    os.system(f"python3 -m onnxsim {onnx_f} {sim_onnx} --dynamic-input-shape --input-shape 1,{h},{w},3")
+    os.system(
+        f"python3 -m onnxsim {onnx_f} {sim_onnx} --dynamic-input-shape --input-shape 1,{h},{w},3"
+    )
     logger.info("generate simplify onnx to: {}".format(sim_onnx))
     if "detr" in sim_onnx:
         # this is needed for detr onnx model
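For context, these names feed the exporter. A sketch of the likely call, where `model`, the NHWC dummy input, and the opset are stand-ins rather than values from this repo:

```python
import torch

input_names, output_names = ["images"], ["masks", "scores"]
dynamic_axes = {"images": {0: "batch"}}
dummy = torch.zeros(1, 544, 960, 3)   # NHWC, matching --input-shape 1,{h},{w},3
torch.onnx.export(
    model,                            # stand-in for the loaded SparseInst model
    dummy,
    "sparse_inst.onnx",
    input_names=input_names,
    output_names=output_names,        # must match inference_onnx's return order
    dynamic_axes=dynamic_axes,
    opset_version=12,                 # assumed opset
)
```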
41 changes: 21 additions & 20 deletions yolov7/modeling/meta_arch/sparseinst.py
@@ -54,7 +54,7 @@ def __init__(self, cfg):
         )
         self.pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1)
         # only for onnx export
-        self.normalizer_trans = lambda x: (x - self.pixel_mean) / self.pixel_std
+        self.normalizer_trans = lambda x: (x - self.pixel_mean.unsqueeze(0)) / self.pixel_std.unsqueeze(0)
 
         # inference
         self.cls_threshold = cfg.MODEL.SPARSE_INST.CLS_THRESHOLD
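Note that (3, 1, 1) statistics already broadcast against an (N, 3, H, W) batch in eager PyTorch; the unsqueeze to (1, 3, 1, 1) presumably just makes the batch axis explicit for the ONNX trace. A quick shape check with illustrative ImageNet-style values:

```python
import torch

pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)
x = torch.rand(2, 3, 64, 64) * 255.0            # hypothetical NCHW batch
y = (x - pixel_mean.unsqueeze(0)) / pixel_std.unsqueeze(0)
print(y.shape)                                  # torch.Size([2, 3, 64, 64])
```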
@@ -94,11 +94,8 @@ def prepare_targets(self, targets):
         return new_targets
 
     def preprocess_inputs_onnx(self, x):
-        x = [xx.permute(2, 0, 1) for xx in x]
-        # print(x.shape)
-        # x = F.interpolate(x, size=(640, 640))
-        # x = F.interpolate(x, size=(512, 960))
-        x = [self.normalizer_trans(xx) for xx in x]
+        x = x.permute(0, 3, 1, 2)
+        x = self.normalizer_trans(x)
         return x
 
     def forward(self, batched_inputs):
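The rewritten preprocess_inputs_onnx takes a single batched NHWC tensor instead of a list of HWC tensors, keeping Python-list handling out of the traced graph. A minimal sketch of the layout change, with illustrative sizes:

```python
import torch

x = torch.zeros(1, 544, 960, 3)  # NHWC, e.g. a loaded image plus a batch axis
x = x.permute(0, 3, 1, 2)        # -> (1, 3, 544, 960), NCHW for the backbone
```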
@@ -108,18 +105,20 @@ def forward(self, batched_inputs):
                 batched_inputs, list
             ), "onnx export, batched_inputs only needs image tensor or list of tensors"
             images = self.preprocess_inputs_onnx(batched_inputs)
+            logger.info(f'images onnx input: {images.shape}')
         else:
             images = self.preprocess_inputs(batched_inputs)
 
-        if isinstance(images, (list, torch.Tensor)):
-            images = nested_tensor_from_tensor_list(images)
+        # if isinstance(images, (list, torch.Tensor)):
+        #     images = nested_tensor_from_tensor_list(images)
 
         if isinstance(images, ImageList):
             max_shape = images.tensor.shape[2:]
             features = self.backbone(images.tensor)
         else:
-            max_shape = images.tensors.shape[2:]
-            features = self.backbone(images.tensors)
+            # onnx trace
+            max_shape = images.shape[2:]
+            features = self.backbone(images)
 
         features = self.encoder(features)
         output = self.decoder(features)
@@ -132,7 +131,7 @@ def forward(self, batched_inputs):
         else:
             if torch.onnx.is_in_onnx_export():
                 results = self.inference_onnx(
-                    output, batched_inputs, max_shape, images.image_sizes
+                    output, batched_inputs, max_shape
                 )
                 return results
             else:
@@ -208,7 +207,7 @@ def inference(self, output, batched_inputs, max_shape, image_sizes):
             results.append(result)
         return results
 
-    def inference_onnx(self, output, batched_inputs, max_shape, image_sizes):
+    def inference_onnx(self, output, batched_inputs, max_shape):
        # max_detections = self.max_detections
        pred_scores = output["pred_logits"].sigmoid()
        pred_masks = output["pred_masks"].sigmoid()
@@ -219,12 +218,13 @@ def inference_onnx(self, output, batched_inputs, max_shape, image_sizes):
         all_scores = []
         all_labels = []
         all_masks = []
+        print('max_shape: ', max_shape)
 
         for _, (
             scores_per_image,
             mask_pred_per_image,
             batched_input,
-            img_shape,
-        ) in enumerate(zip(pred_scores, pred_masks, batched_inputs, image_sizes)):
+        ) in enumerate(zip(pred_scores, pred_masks, batched_inputs)):
 
             # max/argmax
             scores, labels = torch.max(scores_per_image, dim=scores_per_image.dim()-1)
@@ -237,7 +237,7 @@ def inference_onnx(self, output, batched_inputs, max_shape, image_sizes):
             # print(scores, labels)
             mask_pred_per_image = mask_pred_per_image[keep]
 
-            h, w = img_shape
+            h, w = max_shape
             # rescoring mask using maskness
             scores = rescoring_mask(
                 scores, mask_pred_per_image > self.mask_threshold, mask_pred_per_image
@@ -255,15 +255,16 @@ def inference_onnx(self, output, batched_inputs, max_shape, image_sizes):
             )[:, :h, :w]
 
             mask_pred = mask_pred_per_image > self.mask_threshold
-
-            all_masks.append(mask_pred)
             all_scores.append(scores)
             all_labels.append(labels)
+            all_masks.append(mask_pred)
 
 
-        all_masks = torch.stack(all_masks).to(torch.long)
         all_scores = torch.stack(all_scores)
         all_labels = torch.stack(all_labels)
+        all_masks = torch.stack(all_masks).to(torch.long)
         # logger.info(f'all_scores: {all_scores.shape}')
-        logger.info(f'all_labels: {all_labels.shape}')
+        # logger.info(f'all_labels: {all_labels.shape}')
         logger.info(f'all_masks: {all_masks.shape}')
-        return all_masks, all_scores, all_labels
+        return all_masks, all_scores
+        # return all_masks, all_labels
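For context, rescoring_mask is defined elsewhere in the repo. A sketch of the usual SparseInst "maskness" formulation it plausibly matches, assuming per-image masks of shape (N, H, W); the epsilon is illustrative:

```python
import torch

def rescoring_mask(scores, mask_pred, masks):
    # Scale each instance score by the mean soft-mask value inside its
    # binarized region ("maskness"); mask_pred is boolean, masks are soft.
    mask_pred = mask_pred.float()
    return scores * ((masks * mask_pred).sum([1, 2]) / (mask_pred.sum([1, 2]) + 1e-6))
```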
3 changes: 3 additions & 0 deletions yolov7/modeling/transcoders/encoder_sparseinst.py
@@ -31,6 +31,9 @@ def forward(self, x):
         assert len(self.sz) == 2
         kernel_width = math.ceil(inp_size[2] / self.sz[0])
         kernel_height = math.ceil(inp_size[3] / self.sz[1])
+        if torch.is_tensor(kernel_width):
+            kernel_width = kernel_width.item()
+            kernel_height = kernel_height.item()
         return F.avg_pool2d(
             input=x, ceil_mode=False, kernel_size=(kernel_width, kernel_height)
         )
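The guard addresses ONNX tracing: shape arithmetic that yields plain ints in eager mode can come back as 0-d tensors when traced, and F.avg_pool2d expects Python ints in kernel_size. A standalone sketch of the failure mode and fix, with illustrative shapes:

```python
import math
import torch
import torch.nn.functional as F

x = torch.rand(1, 256, 21, 37)
kw = math.ceil(x.shape[2] / 4)     # an int eagerly, but can be a 0-d tensor under tracing
kh = math.ceil(x.shape[3] / 4)
if torch.is_tensor(kw):
    kw, kh = kw.item(), kh.item()  # back to Python ints so the export sees constants
out = F.avg_pool2d(x, kernel_size=(kw, kh), ceil_mode=False)
```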
