Skip to content

Commit 6715d4c

Browse files
authored
Shape inference: GatherBlockQuantized dispatcher (#23748)
### Description Add shape infer dispatcher for `GatherBlockQuantized` contrib op. It reuses the dispatcher for `Gather` op since the first two inputs have the same specs. The output elem type comes from input 2 (scales) for `GatherBlockQuantized`. ### Motivation and Context Support shape inference for models with `GatherBlockQuantized` op.
1 parent 75cf166 commit 6715d4c

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

onnxruntime/python/tools/symbolic_shape_infer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
202202
"EmbedLayerNormalization": self._infer_EmbedLayerNormalization,
203203
"FastGelu": self._infer_FastGelu,
204204
"GatedRelativePositionBias": self._infer_GatedRelativePositionBias,
205+
"GatherBlockQuantized": self._infer_Gather,
205206
"Gelu": self._infer_Gelu,
206207
"GemmFastGelu": self._infer_GemmFastGelu,
207208
"GemmFloat8": self._infer_GemmFloat8,
@@ -459,6 +460,7 @@ def _onnx_infer_single_node(self, node):
459460
"BiasGelu",
460461
"EmbedLayerNormalization",
461462
"FastGelu",
463+
"GatherBlockQuantized",
462464
"Gelu",
463465
"GemmFastGelu",
464466
"LayerNormalization",
@@ -1118,10 +1120,17 @@ def _infer_Gather(self, node): # noqa: N802
11181120
axis = handle_negative_axis(get_attribute(node, "axis", 0), len(data_shape))
11191121
indices_shape = self._get_shape(node, 1)
11201122
vi = self.known_vi_[node.output[0]]
1123+
if node.op_type == "Gather":
1124+
elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
1125+
elif node.op_type == "GatherBlockQuantized":
1126+
# scales
1127+
elem_type = self.known_vi_[node.input[2]].type.tensor_type.elem_type
1128+
else:
1129+
raise ValueError(f"Unsupported Gather op_type: {node.op_type}")
11211130
vi.CopyFrom(
11221131
helper.make_tensor_value_info(
11231132
node.output[0],
1124-
self.known_vi_[node.input[0]].type.tensor_type.elem_type,
1133+
elem_type,
11251134
data_shape[:axis] + indices_shape + data_shape[axis + 1 :],
11261135
)
11271136
)

0 commit comments

Comments
 (0)