diff --git a/legacy/model_zoo/gpt-3/benchmarks/test_tipc/gpt/static/new_exec_pp/benchmark_common/run_benchmark.sh b/legacy/model_zoo/gpt-3/benchmarks/test_tipc/gpt/static/new_exec_pp/benchmark_common/run_benchmark.sh index 15ece4340799..215762d5558e 100644 --- a/legacy/model_zoo/gpt-3/benchmarks/test_tipc/gpt/static/new_exec_pp/benchmark_common/run_benchmark.sh +++ b/legacy/model_zoo/gpt-3/benchmarks/test_tipc/gpt/static/new_exec_pp/benchmark_common/run_benchmark.sh @@ -35,6 +35,18 @@ function _set_params(){ sharding_degree=${10:-"1"} # (可选) sharding_stage=${11:-"1"} # (可选)sharding case level=${12:-"o1"} # o1|o2|o3 + + if [[ $FLAGS_enable_pir_api == "1" || $FLAGS_enable_pir_api == "True" ]]; then + if [ ${level} == "o3" ]; then + level="o2" + echo "amp level changed to o2 in pir mode" + else + echo "amp level is o3" + fi + else + echo "FLAGS_enable_pir_api = 0" + fi + local_batch_size=${13:-"8"} # (可选)本地batch size schedule_mode=${14:-"1F1B"} # (可选)schedule mode base_batch_size=$global_batch_size diff --git a/legacy/model_zoo/gpt-3/benchmarks/test_tipc/gpt/static/new_exec_pp_pir/benchmark_common/run_benchmark.sh b/legacy/model_zoo/gpt-3/benchmarks/test_tipc/gpt/static/new_exec_pp_pir/benchmark_common/run_benchmark.sh index 720ac66d1aba..558e8232ba8a 100644 --- a/legacy/model_zoo/gpt-3/benchmarks/test_tipc/gpt/static/new_exec_pp_pir/benchmark_common/run_benchmark.sh +++ b/legacy/model_zoo/gpt-3/benchmarks/test_tipc/gpt/static/new_exec_pp_pir/benchmark_common/run_benchmark.sh @@ -35,6 +35,18 @@ function _set_params(){ sharding_degree=${10:-"1"} # (可选) sharding_stage=${11:-"1"} # (可选)sharding case level=${12:-"o1"} # o1|o2|o3 + + if [[ $FLAGS_enable_pir_api == "1" || $FLAGS_enable_pir_api == "True" ]]; then + if [ ${level} == "o3" ]; then + level="o2" + echo "amp level changed to o2 in pir mode" + else + echo "amp level is o3" + fi + else + echo "FLAGS_enable_pir_api = 0" + fi + local_batch_size=${13:-"8"} # (可选)本地batch size schedule_mode=${14:-"1F1B"} # (可选)schedule mode base_batch_size=$global_batch_size