From 282c532a1d6fc1209426afb82c2eeead18317527 Mon Sep 17 00:00:00 2001
From: AndrewZhaoLuo
Date: Tue, 29 Jun 2021 04:17:57 -0700
Subject: [PATCH] [AMP] Turn off accumulation data types for mixed precision
 pass (#8341)

* don't use mixed precision accumulators

* turn off fp32 accumulators for now, adjust passing test cases

* Add TODO on cuda codegen for failures. Make test case pass on cuda for now

test to mixed precision

more tests

add internal func call

broadcast failures

moreee

add comment and change lstm unit test to pass on cuda

* remove debug statements

* to mixed precision

* rebase main

* rtol and atol adjustments

* bump up tolerance again

* jostle CI
---
 python/tvm/relay/transform/mixed_precision.py | 15 +---
 tests/python/relay/test_op_level10.py         |  9 +-
 tests/python/relay/test_to_mixed_precision.py | 84 +++++++++----------
 3 files changed, 46 insertions(+), 62 deletions(-)

diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py
index 6aa3ac09cfee..6f8ecb970221 100644
--- a/python/tvm/relay/transform/mixed_precision.py
+++ b/python/tvm/relay/transform/mixed_precision.py
@@ -40,7 +40,7 @@
     "nn.conv2d_transpose",
     "nn.conv3d_transpose",
     "nn.dense",
-    # "nn.batch_matmul",  # Handled by a special case
+    "nn.batch_matmul",
 ]
 DEFAULT_FOLLOW_LIST = [
     # These ops add new data or change shape
@@ -162,7 +162,9 @@ def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) ->
     # Some discussion here about making this better is here:
     # https://discuss.tvm.apache.org/t/rfc-relay-fp32-fp16-model-support/9994/4?u=andrewzhaoluo
     if hasattr(call_node.attrs, "out_dtype"):
-        return ["float32", mixed_precision_type]
+        # TODO (AndrewZhaoLuo): evaluate consistent support for mixed_type accumulators
+        # return ["float32", mixed_precision_type]
+        return [mixed_precision_type, mixed_precision_type]
 
     # [accumulation_dtype, output_dtype] for the operations
     return [mixed_precision_type, mixed_precision_type]
@@ -184,12 +186,3 @@ def generic_follow_op(call_node: relay.Call, mixed_precision_type: str) -> List:
 @register_func_to_op_list(list_ops=DEFAULT_NEVER_LIST)
 def generic_never_op(call_node: relay.Call, mixed_precision_type: str) -> List:
     return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, mixed_precision_type)
-
-
-@register_mixed_precision_conversion("nn.batch_matmul")
-def nn_batch_matmul(call_node: relay.Call, mixed_precision_type: str) -> List:
-    # TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well.
-    # Batched matmul has inconsistent support for mixed precision operations.
-    # Many schedules ignore the out_dtype attribute which leads to errors when
-    # input types do not match the out_dtype. Therefore, accumulate to output_dtype.
-    return [MIXED_PRECISION_ALWAYS, "float16", "float16"]
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 0eddd965c661..24f0ed6642b5 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -18,14 +18,11 @@
 """
 import numpy as np
 import tvm
-from tvm import te
+import tvm.testing
 import tvm.topi.testing
-from tvm import relay
+from tvm import relay, te, topi
 from tvm.relay import transform
 from tvm.relay.testing import run_infer_type
-from tvm import topi
-import tvm.topi.testing
-import tvm.testing
 
 
 @tvm.testing.uses_gpu
@@ -608,7 +605,7 @@ def _verify(prediction_shape, reduction="mean", ignore_index=-100, dtype="float3
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, device=dev, target=target)
                 out_relay = intrp.evaluate(func)(predictions_np, targets_np, weights_np)
-                tvm.testing.assert_allclose(out_relay.asnumpy(), out_np, rtol=1e-4, atol=1e-5)
+                tvm.testing.assert_allclose(out_relay.asnumpy(), out_np, rtol=1e-6, atol=1e-6)
 
     _verify((10, 5))
     _verify((10, 5, 2, 2))
diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py
index caccd52d60c2..7a3fbfafc089 100644
--- a/tests/python/relay/test_to_mixed_precision.py
+++ b/tests/python/relay/test_to_mixed_precision.py
@@ -48,6 +48,7 @@ def verify_mixed_precision_output_close(
     result_fp32 = run_module(mod, mod_params)
     fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
     result_fp16 = run_module(fp16_mod, mod_params)
+    # Ensure the results are close
     for fp32, fp16 in zip(result_fp32, result_fp16):
         np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)
 
@@ -60,7 +61,9 @@ def test_lstm():
 
     Has internal functions and let statements the pass must work on.
     """
-    units = 3
+    # TODO(AndrewZhaoLuo): investigate why non-even units cause failure in codegen for CUDA
+    # See discussion here: https://github.com/apache/tvm/issues/8294#issuecomment-866190408
+    units = 4
     iterations = 5
     mod, mod_params = lstm.get_workload(iterations=iterations, num_hidden=units)
 
@@ -118,16 +121,13 @@ def test_convert_single_conv():
     fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
 
     expected_mod = tvm.IRModule.from_expr(
-        relay.cast(
-            relay.nn.conv2d(
-                relay.cast(data, "float16"),
-                relay.cast(weight, "float16"),
-                strides=(1, 1),
-                padding=(1, 1),
-                out_dtype="float32",
-            ),
-            "float16",
-        )
+        relay.nn.conv2d(
+            relay.cast(data, "float16"),
+            relay.cast(weight, "float16"),
+            strides=(1, 1),
+            padding=(1, 1),
+            out_dtype="float16",
+        ),
     )
     expected_mod = tvm.relay.transform.InferType()(expected_mod)
 
@@ -156,16 +156,13 @@ def test_convert_single_conv_fp64():
     # Note we still accumulate to FP32 by default, a user would need to overwrite default
     # behavior to make this make more sense.
     expected_mod = tvm.IRModule.from_expr(
-        relay.cast(
-            relay.nn.conv2d(
-                relay.cast(data, "float64"),
-                relay.cast(weight, "float64"),
-                strides=(1, 1),
-                padding=(1, 1),
-                out_dtype="float32",
-            ),
-            "float64",
-        )
+        relay.nn.conv2d(
+            relay.cast(data, "float64"),
+            relay.cast(weight, "float64"),
+            strides=(1, 1),
+            padding=(1, 1),
+            out_dtype="float64",
+        ),
     )
     expected_mod = tvm.relay.transform.InferType()(expected_mod)
 
@@ -198,15 +195,12 @@ def test_convert_conv_bn():
         "moving_mean": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
         "moving_var": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
     }
-    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.025, rtol=0.01)
 
     # Creating expected module
     data = relay.cast(relay.var("data", shape=data_shape), "float16")
     weight = relay.cast(relay.var("weight", shape=weight_shape), "float16")
-    conv = relay.cast(
-        relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32"),
-        "float16",
-    )
+    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float16")
 
     bn_shape = [5]
     gamma = relay.cast(relay.var("gamma", shape=bn_shape), "float16")
@@ -254,17 +248,14 @@ def test_green_gray_propagates_simple():
         "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
         "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
     }
-    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01)
 
-    conv_expr = relay.cast(
-        relay.nn.conv2d(
-            relay.cast(data, "float16"),
-            relay.cast(weight, "float16"),
-            strides=(1, 1),
-            padding=(1, 1),
-            out_dtype="float32",
-        ),
-        "float16",
+    conv_expr = relay.nn.conv2d(
+        relay.cast(data, "float16"),
+        relay.cast(weight, "float16"),
+        strides=(1, 1),
+        padding=(1, 1),
+        out_dtype="float16",
     )
     expected_mod = tvm.IRModule.from_expr(conv_expr + conv_expr)
     expected_mod = tvm.relay.transform.InferType()(expected_mod)
@@ -316,12 +307,15 @@ def test_green_red_not_use_extraneous_cast():
     fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
 
     # Construct expected structure
-    conv = relay.nn.conv2d(
-        relay.cast(data, "float16"),
-        relay.cast(weight, "float16"),
-        strides=(1, 1),
-        padding=(1, 1),
-        out_dtype="float32",
+    conv = relay.cast(
+        relay.nn.conv2d(
+            relay.cast(data, "float16"),
+            relay.cast(weight, "float16"),
+            strides=(1, 1),
+            padding=(1, 1),
+            out_dtype="float16",
+        ),
+        "float32",
     )
     result = relay.nn.softmax(conv)
     expected_mod = tvm.IRModule.from_expr(result)
@@ -380,12 +374,12 @@ def test_let_statement_simple():
     r2 = var2 + var2
     let2 = relay.Let(
         var2,
-        relay.cast(relay.nn.dense(r1, weight, units=20, out_dtype="float32"), "float16"),
+        relay.nn.dense(r1, weight, units=20, out_dtype="float16"),
         r2,
     )
     let1 = relay.Let(
         var1,
-        relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16"),
+        relay.nn.dense(data, weight, units=20, out_dtype="float16"),
         let2,
     )
     expected_mod = tvm.IRModule.from_expr(let1)
@@ -410,7 +404,7 @@ def test_where_simple():
     # Create expected module
     data = relay.cast(relay.var("data", shape=[1, 20]), "float16")
     weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16")
-    a = relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16")
+    a = relay.nn.dense(data, weight, units=20, out_dtype="float16")
     b = relay.where(data, a, a)
    expected_mod = tvm.IRModule.from_expr(b)
     expected_mod = InferType()(expected_mod)
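
Not part of the patch: the short sketch below illustrates the behavioral change, assuming the ToMixedPrecision pass exported from tvm.relay.transform in this version of TVM; the module, shapes, and variable names are made up for illustration only. With accumulation dtypes turned off, converted ops such as nn.conv2d and nn.dense carry out_dtype equal to the mixed precision type instead of a float32 accumulator followed by a cast back to float16.

    # Illustrative only: inspect the output of ToMixedPrecision after this change.
    import tvm
    from tvm import relay
    from tvm.relay.transform import InferType, ToMixedPrecision

    data = relay.var("data", shape=[1, 3, 32, 32])  # float32 by default
    weight = relay.var("weight", shape=[8, 3, 3, 3])
    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1))
    mod = InferType()(tvm.IRModule.from_expr(conv))

    # After this change the converted conv2d is emitted with out_dtype="float16"
    # directly; previously it used out_dtype="float32" plus a cast to "float16".
    fp16_mod = ToMixedPrecision("float16")(mod)
    print(fp16_mod)

Printing fp16_mod should show the conv2d taking float16 inputs with out_dtype="float16", matching the expected modules in the updated tests above.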