Skip to content

Commit

Permalink
Add test for reload unmodified version with new version model file added
Browse files Browse the repository at this point in the history
  • Loading branch information
kthui committed Aug 14, 2024
1 parent ee04663 commit aabe4ca
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 87 deletions.
77 changes: 27 additions & 50 deletions qa/L0_lifecycle/lifecycle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3497,78 +3497,55 @@ def test_load_new_model_version(self):
model_name = "identity_fp32"
client = self._get_client(use_grpc=True)

self.assertTrue(client.is_model_ready(model_name, "1"))
self.assertFalse(client.is_model_ready(model_name, "2"))

with open(os.environ["SERVER_LOG"]) as f:
server_log = f.read()
self.assertEqual(server_log.count("[PB model] Loading version 1"), 1)
self.assertEqual(server_log.count("[PB model] Loading version 2"), 0)
self.assertEqual(server_log.count("successfully loaded 'identity_fp32'"), 1)

# Update model config to also load version 2
config_path = os.path.join("models", model_name, "config.pbtxt")
with open(config_path, mode="r+", encoding="utf-8", errors="strict") as f:
config = f.read()
config = config.replace(
"version_policy: { specific: { versions: [1] } }",
"version_policy: { specific: { versions: [1, 2] } }",
)
f.truncate(0)
f.seek(0)
f.write(config)
# Reload the model and version 1 should not be reloaded
client.load_model(model_name)

# versions 1 and 2 are already loaded
# version 3 is in the model directory but not loaded
# version 4 does not exist anywhere
self.assertTrue(client.is_model_ready(model_name, "1"))
self.assertTrue(client.is_model_ready(model_name, "2"))

with open(os.environ["SERVER_LOG"]) as f:
server_log = f.read()
self.assertEqual(
server_log.count("[PB model] Loading version 1"),
1,
"version 1 should not be reloaded",
)
self.assertEqual(server_log.count("[PB model] Loading version 2"), 1)
self.assertEqual(server_log.count("successfully loaded 'identity_fp32'"), 2)

def test_update_loaded_version_and_load_new_version(self):
model_name = "identity_fp32"
client = self._get_client(use_grpc=True)

self.assertTrue(client.is_model_ready(model_name, "1"))
self.assertFalse(client.is_model_ready(model_name, "2"))

self.assertFalse(client.is_model_ready(model_name, "3"))
self.assertFalse(client.is_model_ready(model_name, "4"))
with open(os.environ["SERVER_LOG"]) as f:
server_log = f.read()
self.assertEqual(server_log.count("[PB model] Loading version 1"), 1)
self.assertEqual(server_log.count("[PB model] Loading version 2"), 0)
self.assertEqual(server_log.count("[PB model] Loading version 2"), 1)
self.assertEqual(server_log.count("[PB model] Loading version 3"), 0)
self.assertEqual(server_log.count("[PB model] Loading version 4"), 0)
self.assertEqual(server_log.count("successfully loaded 'identity_fp32'"), 1)

# Update model file of version 1
Path(os.path.join("models", model_name, "1", "model.py")).touch()
# Update model config to also load version 2
# update version 2 model file
Path(os.path.join("models", model_name, "2", "model.py")).touch()
# add version 4 model file
src_path = os.path.join("models", model_name, "3")
dst_path = os.path.join("models", model_name, "4")
shutil.copytree(src_path, dst_path)
# update model config to load version 1 to 4
config_path = os.path.join("models", model_name, "config.pbtxt")
with open(config_path, mode="r+", encoding="utf-8", errors="strict") as f:
config = f.read()
config = config.replace(
"version_policy: { specific: { versions: [1] } }",
"version_policy: { specific: { versions: [1, 2] } }",
"version_policy: { specific: { versions: [1, 2, 3, 4] } }",
)
f.truncate(0)
f.seek(0)
f.write(config)
# Reload the model and version 1 should be reloaded
# reload the model
client.load_model(model_name)

# version 1 is unmodified so it should not be reloaded
# version 2 is modified so it should be reloaded
# version 3's model file already existed but was not loaded, so it should now be loaded
# version 4 is a new version so it should be loaded
self.assertTrue(client.is_model_ready(model_name, "1"))
self.assertTrue(client.is_model_ready(model_name, "2"))

self.assertTrue(client.is_model_ready(model_name, "3"))
self.assertTrue(client.is_model_ready(model_name, "4"))
with open(os.environ["SERVER_LOG"]) as f:
server_log = f.read()
self.assertEqual(server_log.count("[PB model] Loading version 1"), 2)
self.assertEqual(server_log.count("[PB model] Loading version 2"), 1)
self.assertEqual(server_log.count("[PB model] Loading version 1"), 1)
self.assertEqual(server_log.count("[PB model] Loading version 2"), 2)
self.assertEqual(server_log.count("[PB model] Loading version 3"), 1)
self.assertEqual(server_log.count("[PB model] Loading version 4"), 1)
self.assertEqual(server_log.count("successfully loaded 'identity_fp32'"), 2)


Expand Down
40 changes: 3 additions & 37 deletions qa/L0_lifecycle/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2202,11 +2202,12 @@ LOG_IDX=$((LOG_IDX+1))
rm -rf models
mkdir models
cp -r ../python_models/identity_fp32 models/ && (cd models/identity_fp32 && \
echo "version_policy: { specific: { versions: [1] } }" >> config.pbtxt && \
echo "version_policy: { specific: { versions: [1, 2] } }" >> config.pbtxt && \
echo " def initialize(self, args):" >> model.py && \
echo " pb_utils.Logger.log_info(f'[PB model] Loading version {args[\"model_version\"]}')" >> model.py && \
mkdir 1 && cp model.py 1 && \
mkdir 2 && mv model.py 2)
mkdir 2 && cp model.py 2 && \
mkdir 3 && mv model.py 3)

export PYTHONDONTWRITEBYTECODE="True"
SERVER_ARGS="--model-repository=`pwd`/models --model-control-mode=explicit --load-model=*"
Expand All @@ -2231,41 +2232,6 @@ kill $SERVER_PID
wait $SERVER_PID
unset PYTHONDONTWRITEBYTECODE

LOG_IDX=$((LOG_IDX+1))

# LifeCycleTest.test_update_loaded_version_and_load_new_version
rm -rf models
mkdir models
cp -r ../python_models/identity_fp32 models/ && (cd models/identity_fp32 && \
echo "version_policy: { specific: { versions: [1] } }" >> config.pbtxt && \
echo " def initialize(self, args):" >> model.py && \
echo " pb_utils.Logger.log_info(f'[PB model] Loading version {args[\"model_version\"]}')" >> model.py && \
mkdir 1 && cp model.py 1 && \
mkdir 2 && mv model.py 2)

export PYTHONDONTWRITEBYTECODE="True"
SERVER_ARGS="--model-repository=`pwd`/models --model-control-mode=explicit --load-model=*"
SERVER_LOG="./inference_server_$LOG_IDX.log"
run_server
if [ "$SERVER_PID" == "0" ]; then
echo -e "\n***\n*** Failed to start $SERVER\n***"
cat $SERVER_LOG
exit 1
fi

set +e
SERVER_LOG=$SERVER_LOG python $LC_TEST LifeCycleTest.test_update_loaded_version_and_load_new_version >>$CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
cat $CLIENT_LOG
echo -e "\n***\n*** Test Failed\n***"
RET=1
fi
set -e

kill $SERVER_PID
wait $SERVER_PID
unset PYTHONDONTWRITEBYTECODE

if [ $RET -eq 0 ]; then
echo -e "\n***\n*** Test Passed\n***"
else
Expand Down

0 comments on commit aabe4ca

Please sign in to comment.