Skip to content

Commit

Permalink
Entrypoint download from url (#628)
Browse files Browse the repository at this point in the history
  • Loading branch information
sindhuvahinis authored Apr 14, 2023
1 parent 7271063 commit edc25da
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 6 deletions.
9 changes: 9 additions & 0 deletions serving/docker/partition/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ def upload_checkpoints_to_s3(self):
subprocess.run(commands)
shutil.rmtree(self.properties["save_mp_checkpoint_path"])

def cleanup(self):
    """
    Cleans up the entrypoint file that was downloaded from a url.

    Does nothing when the entrypoint was not downloaded (no
    ``entry_point_url`` recorded on the properties manager).

    The download directory is normally removed as a whole. When the
    download directory is the properties (model) directory itself, only
    the downloaded entrypoint file is deleted so that the model files,
    serving.properties and partitioned checkpoints are not wiped out.
    """
    if not self.properties_manager.entry_point_url:
        return

    entrypoint_file = Path(self.properties['entryPoint'])
    entrypoint_dir = entrypoint_file.parent
    properties_dir = Path(self.properties_manager.properties_dir)

    if entrypoint_dir.resolve() == properties_dir.resolve():
        # The file was downloaded straight into the model directory;
        # removing the directory tree would destroy the model itself.
        try:
            entrypoint_file.unlink()
        except FileNotFoundError:
            pass
    else:
        shutil.rmtree(entrypoint_dir)

def run_partition(self):
commands = get_partition_cmd(self.properties_manager.is_mpi_mode,
self.properties)
Expand All @@ -161,6 +169,7 @@ def run_partition(self):
self.properties_manager.generate_properties_file()
self.copy_config_files()
self.upload_checkpoints_to_s3()
self.cleanup()
else:
raise Exception("Partitioning was not successful.")

Expand Down
30 changes: 26 additions & 4 deletions serving/docker/partition/properties_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import logging
import os
import glob
import torch
import requests

# Properties to exclude while generating serving.properties
from utils import is_engine_mpi_mode, get_engine_configs, get_download_dir
Expand All @@ -23,12 +25,15 @@

PARTITION_SUPPORTED_ENGINES = ['DeepSpeed', 'FasterTransformer']

CHUNK_SIZE = 4096 # 4 KiB (4096 bytes) chunk size used when streaming the entrypoint download


class PropertiesManager(object):

def __init__(self, properties_dir):
self.properties = {}
self.properties_dir = properties_dir
self.entry_point_url = None

self.load_properties()

Expand Down Expand Up @@ -93,9 +98,12 @@ def generate_properties_file(self):

for key, value in self.properties.items():
if key not in EXCLUDE_PROPERTIES:
if key == "entryPoint" and self.properties.get(
"entryPoint") == "model.py":
continue
if key == "entryPoint":
entry_point = self.properties.get("entryPoint")
if entry_point == "model.py":
continue
elif self.entry_point_url:
configs[f'option.{key}'] = self.entry_point_url
else:
configs[f'option.{key}'] = value

Expand Down Expand Up @@ -127,7 +135,8 @@ def validate_tp_degree(self):
)

def set_and_validate_entry_point(self):
if "entryPoint" not in self.properties:
entry_point = self.properties.get('entryPoint')
if entry_point is None:
entry_point = os.environ.get("DJL_ENTRY_POINT")
if entry_point is None:
entry_point_file = glob.glob(
Expand All @@ -145,6 +154,19 @@ def set_and_validate_entry_point(self):
f"model.py not found in model path {self.properties_dir}"
)
self.properties['entryPoint'] = entry_point
elif entry_point.lower().startswith('http'):
logging.info(f'Downloading entrypoint file.')
self.entry_point_url = self.properties['entryPoint']
download_dir = get_download_dir(self.properties_dir,
suffix='modelfile')
model_file = os.path.join(download_dir, 'model.py')
with requests.get(self.properties['entryPoint'], stream=True) as r:
with open(model_file, 'wb') as f:
for chunk in r.iter_content(chunk_size=CHUNK_SIZE):
if chunk:
f.write(chunk)
self.properties['entryPoint'] = model_file
logging.info(f'Entrypoint file downloaded successfully')

def set_and_validate_save_mp_checkpoint_path(self):
save_mp_checkpoint_path = self.properties["save_mp_checkpoint_path"]
Expand Down
4 changes: 2 additions & 2 deletions serving/docker/partition/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
import json
import glob
import tempfile
import zipfile
import tempfile

MASTER_ADDR = "127.0.0.1"
MASTER_PORT = 29761
Expand Down Expand Up @@ -84,7 +84,7 @@ def is_engine_mpi_mode(engine):


def get_download_dir(properties_dir, suffix=""):
tmp = tempfile.mktemp(suffix=suffix, prefix="download")
tmp = tempfile.mkdtemp(suffix=suffix, prefix="download")
download_dir = os.environ.get("SERVING_DOWNLOAD_DIR", tmp)
if download_dir == "default":
download_dir = properties_dir
Expand Down

0 comments on commit edc25da

Please sign in to comment.