
Commit 76e2b46

Update (lint fix)

1 parent: 9b1450a

File tree: 4 files changed, +25 -12 lines


python/tvm/relay/frontend/tensorflow.py

Lines changed: 5 additions & 5 deletions

@@ -56,7 +56,7 @@
     # Change this flag to False to directly convert to `nn.batch_matmul`.
    # Note that `nn.batch_matmul` with a format other than NT is experimental; it may have
    # performance issues.
-    "use_nt_batch_matmul_op": True,
+    "use_nt_batch_matmul": True,
 }

 # compatible operators that do NOT require any conversion.

@@ -1219,7 +1219,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
         return func, self._params


-def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, convert_config={}):
+def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, convert_config=None):
     """Load a tensorflow graph, which is a python tensorflow graph object, into relay.
     The companion parameters will be handled automatically.

@@ -1237,13 +1237,13 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, convert_conf
     outputs : List of output tensor names (Optional)
         if not specified then the last node is assumed as graph output.

-    convert_config : Dict[str, Any]
+    convert_config : Optional[Dict[str, Any]]
         Default config:
-            use_dense_op : bool = True
+            use_dense : bool = True
                 True to convert `tf.matmul` to `nn.dense`, else to `nn.matmul`.
                 The `nn.dense` op requires the data tensor to be non-transposed and the weight
                 tensor to be transposed, so it may insert extra `transpose` ops into the original graph.
-            use_nt_batch_matmul_op : bool = True
+            use_nt_batch_matmul : bool = True
                 True to convert `tf.batch_matmul` to `nn.batch_matmul` strictly in NT format
                 (transpose_a=False, transpose_b=True).
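The `convert_config={}` to `convert_config=None` change sidesteps Python's mutable-default-argument pitfall: a dict default is created once, at function definition time, and the same object is shared by every call that omits the argument. Below is a minimal sketch of the failure mode and the `None` idiom; `load_graph_bad` and `load_graph_good` are hypothetical stand-ins, not TVM functions, and only the parameter pattern mirrors this commit.

def load_graph_bad(convert_config={}):
    # The default dict is created once and shared across all calls.
    convert_config.setdefault("use_dense", True)
    return convert_config

cfg = load_graph_bad()
cfg["use_dense"] = False      # mutates the shared default object
print(load_graph_bad())       # {'use_dense': False} -- state leaked between calls

def load_graph_good(convert_config=None):
    # Fresh defaults on every call; caller overrides are merged on top.
    config = {"use_dense": True, "use_nt_batch_matmul": True}
    config.update(convert_config or {})
    return config

print(load_graph_good())                      # both defaults True
print(load_graph_good({"use_dense": False}))  # caller override applied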
python/tvm/relay/frontend/tensorflow_ops.py

Lines changed: 1 addition & 1 deletion

@@ -1176,7 +1176,7 @@ def _impl(inputs, attr, params, mod):
     adj_x = attr["adj_x"]
     adj_y = attr["adj_y"]

-    if TF_DEFAULT_CONFIGS["use_nt_batch_matmul_op"]:
+    if TF_DEFAULT_CONFIGS["use_nt_batch_matmul"]:
         # Strictly convert all batch_matmul to NT format
         input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x
         input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y
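With `use_nt_batch_matmul` enabled, the converter rewrites every `tf.batch_matmul` so the kernel always sees NT layout (transpose_a=False, transpose_b=True): an adjointed LHS is materialized with an explicit transpose, and the RHS is pre-transposed exactly when it is not adjointed, since the NT kernel will transpose it again. A NumPy sketch of that algebra; the shapes and the `batch_matmul_nt` helper are illustrative, not the Relay code.

import numpy as np

def batch_matmul_nt(x, y, adj_x=False, adj_y=False):
    # Normalize to NT form: the kernel below always computes x @ swapaxes(y).
    # Un-adjoint the LHS with an explicit transpose.
    x = np.swapaxes(x, 1, 2) if adj_x else x
    # Pre-transpose the RHS only when it is NOT adjointed, because the
    # NT kernel transposes it once more.
    y = np.swapaxes(y, 1, 2) if not adj_y else y
    return x @ np.swapaxes(y, 1, 2)  # the NT-format kernel

rng = np.random.default_rng(0)
x = rng.random((2, 3, 4))
y = rng.random((2, 4, 5))
assert np.allclose(batch_matmul_nt(x, y), x @ y)
# Adjointed LHS: pass x already transposed, set adj_x.
assert np.allclose(batch_matmul_nt(np.swapaxes(x, 1, 2), y, adj_x=True), x @ y)
# Adjointed RHS: pass y already transposed, set adj_y.
assert np.allclose(batch_matmul_nt(x, np.swapaxes(y, 1, 2), adj_y=True), x @ y)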

python/tvm/topi/nn/batch_matmul.py

Lines changed: 9 additions & 0 deletions

@@ -16,10 +16,13 @@
 # under the License.
 """Batch matrix multiplication"""
 # pylint: disable=invalid-name
+import logging
 import tvm
 from tvm import te, auto_scheduler
 from ..utils import get_const_tuple

+logger = logging.getLogger("topi")
+

 def batch_matmul(
     tensor_a,

@@ -94,6 +97,12 @@ def batch_matmul(
     oshape = (batch, XI, YJ)
     if out_dtype is None:
         out_dtype = tensor_a.dtype
+        if tensor_a.dtype != tensor_b.dtype:
+            logger.warning(
+                "tensor_a has different data type with tensor_b: %s, %s",
+                tensor_a.dtype,
+                tensor_b.dtype,
+            )

     if (transpose_a, transpose_b) == (True, True):
         compute_lambda = lambda b, i, j: te.sum(
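Note that the new check only logs through the `topi` logger; it neither raises nor casts, and `out_dtype` still defaults to `tensor_a.dtype`. A plain-Python sketch of the same guard, with NumPy arrays standing in for the TE tensors:

import logging
import numpy as np

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("topi")

def check_batch_matmul_dtypes(tensor_a, tensor_b, out_dtype=None):
    # Mirrors the diff: default the output dtype to tensor_a's, and
    # warn (rather than raise) when the operand dtypes disagree.
    if out_dtype is None:
        out_dtype = tensor_a.dtype
        if tensor_a.dtype != tensor_b.dtype:
            logger.warning(
                "tensor_a has different data type with tensor_b: %s, %s",
                tensor_a.dtype,
                tensor_b.dtype,
            )
    return out_dtype

a = np.zeros((2, 3, 4), dtype="float32")
b = np.zeros((2, 5, 4), dtype="float16")
print(check_batch_matmul_dtypes(a, b))  # float32, plus a logged warning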

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 10 additions & 6 deletions

@@ -124,7 +124,7 @@ def run_tvm_graph(
     disabled_pass=None,
     ignore_in_shape=False,
     serialize=False,
-    convert_config={},
+    convert_config=None,
 ):
     """Generic function to compile on relay and execute on tvm"""
     input_data = convert_to_list(input_data)

@@ -225,7 +225,7 @@ def compare_tf_with_tvm(
     add_shapes_to_graph_def=True,
     targets=None,
     ignore_in_shape=False,
-    convert_config={},
+    convert_config=None,
 ):
     """Generic function to generate and compare tensorflow and TVM output"""

@@ -1811,8 +1811,12 @@ def _test_matmul(i, j, k, dtype, outer=None):

     A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
     B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
-    compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name, {"use_dense_op": True})
-    compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name, {"use_dense_op": False})
+    compare_tf_with_tvm(
+        [A_np, B_np], [A.name, B.name], result.name, {"use_dense": True}
+    )
+    compare_tf_with_tvm(
+        [A_np, B_np], [A.name, B.name], result.name, {"use_dense": False}
+    )


 def test_forward_matmul():

@@ -1831,10 +1835,10 @@ def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False
     A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
     B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
     compare_tf_with_tvm(
-        [A_np, B_np], [A.name, B.name], result.name, {"use_nt_batch_matmul_op": True}
+        [A_np, B_np], [A.name, B.name], result.name, {"use_nt_batch_matmul": True}
     )
     compare_tf_with_tvm(
-        [A_np, B_np], [A.name, B.name], result.name, {"use_nt_batch_matmul_op": False}
+        [A_np, B_np], [A.name, B.name], result.name, {"use_nt_batch_matmul": False}
     )
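Taken together, the renamed keys are what frontend callers now pass, and the tests above exercise both values of each. A hedged usage sketch of `relay.frontend.from_tensorflow` with the new `convert_config`; the frozen-graph path, input/output names, and shapes are placeholders, not part of this commit.

# Usage sketch: feed the renamed config keys through from_tensorflow.
# "frozen.pb", "input", and "output" are illustrative placeholders.
import tensorflow as tf
from tvm import relay

with tf.io.gfile.GFile("frozen.pb", "rb") as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())

mod, params = relay.frontend.from_tensorflow(
    graph_def,
    layout="NHWC",
    shape={"input": (1, 224, 224, 3)},
    outputs=["output"],
    # Renamed keys from this commit; omit to get the defaults (both True).
    convert_config={"use_dense": False, "use_nt_batch_matmul": False},
)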
