Skip to content

Commit 8afeaf6

Browse files
masahizhiics
authored andcommitted
[Quantization, Calibrate] Fix context creation when current_target is explicity set (apache#4582)
1 parent dbd9ee9 commit 8afeaf6

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

python/tvm/relay/quantize/_calibrate.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,18 @@ def collect_stats(mod, dataset):
5353
logging.info("collecting statistics for calibration...")
5454
func = mod['main']
5555
func = _quantize.CreateStatsCollector(func)
56-
target = tvm.target.current_target() or 'llvm'
56+
57+
if tvm.target.current_target():
58+
target = tvm.target.current_target()
59+
ctx = tvm.context(target.target_name)
60+
else:
61+
target = 'llvm'
62+
ctx = tvm.context(target)
63+
5764
with _transform.build_config(opt_level=3):
5865
graph, lib, params = _build_module.build(func, target=target)
5966
outputs = []
60-
runtime = graph_runtime.create(graph, lib, tvm.context(target))
67+
runtime = graph_runtime.create(graph, lib, ctx)
6168
runtime.set_input(**params)
6269

6370
num_outputs = runtime.get_num_outputs()

tests/python/relay/test_pass_auto_quantize.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
import numpy as np
1718
import tvm
1819
from tvm import relay
1920
from tvm.relay import testing
@@ -45,5 +46,28 @@ def test_mul_rewrite():
4546

4647
quantize_and_build(act * pool)
4748

49+
50+
def get_calibration_dataset(input_name):
51+
dataset = []
52+
for i in range(5):
53+
data = np.random.uniform(size=(1, 3, 224, 224))
54+
dataset.append({input_name: data})
55+
return dataset
56+
57+
58+
def test_calibrate_target(create_target=False):
59+
mod, params = testing.resnet.get_workload(num_layers=18)
60+
dataset = get_calibration_dataset("data")
61+
with relay.quantize.qconfig(calibrate_mode="kl_divergence"):
62+
if create_target:
63+
with tvm.target.create("llvm"):
64+
relay.quantize.quantize(mod, params, dataset)
65+
else:
66+
# current_target = None
67+
relay.quantize.quantize(mod, params, dataset)
68+
69+
4870
if __name__ == "__main__":
4971
test_mul_rewrite()
72+
test_calibrate_target(False)
73+
test_calibrate_target(True)

0 commit comments

Comments
 (0)