diff --git a/qa/L0_model_update/instance_update_test.py b/qa/L0_model_update/instance_update_test.py
index d3021b650c9..4c211830fca 100644
--- a/qa/L0_model_update/instance_update_test.py
+++ b/qa/L0_model_update/instance_update_test.py
@@ -34,6 +34,7 @@ from tritonclient.utils import InferenceServerException
 from models.model_init_del.util import (get_count, reset_count, set_delay,
                                         update_instance_group,
+                                        update_sequence_batching,
                                         update_model_file, enable_batching,
                                         disable_batching)
@@ -43,9 +44,21 @@ class TestInstanceUpdate(unittest.TestCase):
     __model_name = "model_init_del"
 
     def setUp(self):
-        # Initialize client
+        self.__reset_model()
         self.__triton = grpcclient.InferenceServerClient("localhost:8001")
 
+    def __reset_model(self):
+        # Reset counters
+        reset_count("initialize")
+        reset_count("finalize")
+        # Reset batching
+        disable_batching()
+        # Reset delays
+        set_delay("initialize", 0)
+        set_delay("infer", 0)
+        # Reset sequence batching
+        update_sequence_batching("")
+
     def __get_inputs(self, batching=False):
         self.assertIsInstance(batching, bool)
         if batching:
@@ -85,14 +98,8 @@ def __check_count(self, kind, expected_count, poll=False):
         self.assertEqual(get_count(kind), expected_count)
 
     def __load_model(self, instance_count, instance_config="", batching=False):
-        # Reset counters
-        reset_count("initialize")
-        reset_count("finalize")
         # Set batching
         enable_batching() if batching else disable_batching()
-        # Reset delays
-        set_delay("initialize", 0)
-        set_delay("infer", 0)
         # Load model
         self.__update_instance_count(instance_count,
                                      0,
@@ -143,6 +150,7 @@ def test_add_rm_add_instance(self):
             self.__update_instance_count(1, 0, batching=batching)  # add
             stop()
             self.__unload_model(batching=batching)
+            self.__reset_model()  # for next iteration
 
     # Test remove -> add -> remove an instance
     def test_rm_add_rm_instance(self):
@@ -154,6 +162,7 @@ def test_rm_add_rm_instance(self):
             self.__update_instance_count(0, 1, batching=batching)  # remove
             stop()
             self.__unload_model(batching=batching)
+            self.__reset_model()  # for next iteration
 
     # Test reduce instance count to zero
     def test_rm_instance_to_zero(self):
@@ -341,15 +350,89 @@ def infer():
         # Unload model
         self.__unload_model()
 
-    # Test for instance update on direct sequence scheduling
-    @unittest.skip("Sequence will not continue after update [FIXME: DLIS-4820]")
-    def test_instance_update_on_direct_sequence_scheduling(self):
-        pass
-
-    # Test for instance update on oldest sequence scheduling
-    @unittest.skip("Sequence will not continue after update [FIXME: DLIS-4820]")
-    def test_instance_update_on_oldest_sequence_scheduling(self):
-        pass
+    # Test wait for in-flight sequence completion and block new sequence
+    def test_sequence_instance_update(self):
+        for sequence_batching_type in [
+                "direct { }\nmax_sequence_idle_microseconds: 10000000",
+                "oldest { max_candidate_sequences: 4 }\nmax_sequence_idle_microseconds: 10000000"
+        ]:
+            # Load model
+            update_instance_group("{\ncount: 2\nkind: KIND_CPU\n}")
+            update_sequence_batching(sequence_batching_type)
+            self.__triton.load_model(self.__model_name)
+            self.__check_count("initialize", 2)
+            self.__check_count("finalize", 0)
+            # Basic sequence inference
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1,
+                                sequence_start=True)
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1)
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1,
+                                sequence_end=True)
+            # Update instance
+            update_instance_group("{\ncount: 4\nkind: KIND_CPU\n}")
+            self.__triton.load_model(self.__model_name)
+            self.__check_count("initialize", 4)
+            self.__check_count("finalize", 0)
+            # Start an in-flight sequence
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=1,
+                                sequence_start=True)
+            # Check update instance will wait for in-flight sequence completion
+            # and block new sequence from starting.
+            update_instance_group("{\ncount: 3\nkind: KIND_CPU\n}")
+            update_complete = [False]
+            def update():
+                self.__triton.load_model(self.__model_name)
+                update_complete[0] = True
+                self.__check_count("initialize", 4)
+                self.__check_count("finalize", 1)
+            infer_complete = [False]
+            def infer():
+                self.__triton.infer(self.__model_name,
+                                    self.__get_inputs(),
+                                    sequence_id=2,
+                                    sequence_start=True)
+                infer_complete[0] = True
+            with concurrent.futures.ThreadPoolExecutor() as pool:
+                # Update should wait until sequence 1 end
+                update_thread = pool.submit(update)
+                time.sleep(2)  # make sure update has started
+                self.assertFalse(update_complete[0],
+                                 "Unexpected update completion")
+                # New sequence should wait until update complete
+                infer_thread = pool.submit(infer)
+                time.sleep(2)  # make sure infer has started
+                self.assertFalse(infer_complete[0],
+                                 "Unexpected infer completion")
+                # End sequence 1 should unblock update
+                self.__triton.infer(self.__model_name,
+                                    self.__get_inputs(),
+                                    sequence_id=1,
+                                    sequence_end=True)
+                time.sleep(2)  # make sure update has returned
+                self.assertTrue(update_complete[0], "Update possibly stuck")
+                update_thread.result()
+                # Update completion should unblock new sequence
+                time.sleep(2)  # make sure infer has returned
+                self.assertTrue(infer_complete[0], "Infer possibly stuck")
+                infer_thread.result()
+            # End sequence 2
+            self.__triton.infer(self.__model_name,
+                                self.__get_inputs(),
+                                sequence_id=2,
+                                sequence_end=True)
+            # Unload model
+            self.__triton.unload_model(self.__model_name)
+            self.__check_count("initialize", 4)
+            self.__check_count("finalize", 4, True)
+            self.__reset_model()
 
 
 if __name__ == "__main__":
diff --git a/qa/python_models/model_init_del/config.pbtxt b/qa/python_models/model_init_del/config.pbtxt
index ee0ed17d26b..be66468a0a3 100644
--- a/qa/python_models/model_init_del/config.pbtxt
+++ b/qa/python_models/model_init_del/config.pbtxt
@@ -49,4 +49,4 @@ instance_group [
     count: 1
     kind: KIND_CPU
   }
-]
+] # end instance_group
diff --git a/qa/python_models/model_init_del/util.py b/qa/python_models/model_init_del/util.py
index 6b77dde8066..0cc40aa9b7a 100644
--- a/qa/python_models/model_init_del/util.py
+++ b/qa/python_models/model_init_del/util.py
@@ -113,10 +113,31 @@ def update_instance_group(instance_group_str):
     full_path = os.path.join(os.path.dirname(__file__), "config.pbtxt")
     with open(full_path, mode="r+", encoding="utf-8", errors="strict") as f:
         txt = f.read()
-        txt = txt.split("instance_group [")[0]
+        txt, post_match = txt.split("instance_group [")
         txt += "instance_group [\n"
         txt += instance_group_str
-        txt += "\n]\n"
+        txt += "\n] # end instance_group\n"
+        txt += post_match.split("\n] # end instance_group\n")[1]
+        f.truncate(0)
+        f.seek(0)
+        f.write(txt)
+        return txt
+
+def update_sequence_batching(sequence_batching_str):
+    full_path = os.path.join(os.path.dirname(__file__), "config.pbtxt")
+    with open(full_path, mode="r+", encoding="utf-8", errors="strict") as f:
+        txt = f.read()
+        if "sequence_batching {" in txt:
+            txt, post_match = txt.split("sequence_batching {")
+            if sequence_batching_str != "":
+                txt += "sequence_batching {\n"
+                txt += sequence_batching_str
+                txt += "\n} # end sequence_batching\n"
+            txt += post_match.split("\n} # end sequence_batching\n")[1]
+        elif sequence_batching_str != "":
+            txt += "\nsequence_batching {\n"
+            txt += sequence_batching_str
+            txt += "\n} # end sequence_batching\n"
         f.truncate(0)
         f.seek(0)
         f.write(txt)