diff --git a/serving/docker/convert/.gitignore b/serving/docker/convert/.gitignore new file mode 100644 index 0000000000..a96c5d3f5d --- /dev/null +++ b/serving/docker/convert/.gitignore @@ -0,0 +1,4 @@ +__pycache__ +model/ +tmp/ +models.json diff --git a/serving/docker/convert/arg_parser.py b/serving/docker/convert/arg_parser.py new file mode 100644 index 0000000000..4561cbc02f --- /dev/null +++ b/serving/docker/convert/arg_parser.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python +# +# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. + +import argparse +import os + + +def converter_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-l", + "--limit", + type=int, + default=1, + help="Max amount of models to convert") + parser.add_argument("-o", "--output-dir", help="Model output directory") + parser.add_argument("-f", + "--output-format", + default="PyTorch", + choices=["PyTorch", "OnnxRuntime", "Rust"], + help="Model output format") + parser.add_argument("-r", + "--retry-failed", + action='store_true', + help="Retry failed model") + parser.add_argument("-u", + "--cpu-only", + action='store_true', + help="Only validate jit traced model on CPU") + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( + "-c", + "--category", + help="Model category to convert", + ) + group.add_argument("-m", "--model-name", help="Model name to convert") + + args = parser.parse_args() + if args.output_dir is None: + args.output_dir = "." + + if not os.path.exists(args.output_dir): + raise ValueError(f"Invalid output directory: {args.output_dir}.") + + return args diff --git a/serving/docker/convert/fill_mask_converter.py b/serving/docker/convert/fill_mask_converter.py new file mode 100644 index 0000000000..ff9de4bea2 --- /dev/null +++ b/serving/docker/convert/fill_mask_converter.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python +# +# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. +import logging + +import torch + +from huggingface_converter import HuggingfaceConverter + + +class FillMaskConverter(HuggingfaceConverter): + + def __init__(self): + super().__init__() + self.task = "fill-mask" + self.application = "nlp/fill_mask" + self.translator = "ai.djl.huggingface.translator.FillMaskTranslatorFactory" + self.inputs = "Hello I'm a [MASK] model." 
+ self.outputs = [ + "fashion", "role", 'new', 'super', 'fine', 'male', 'female', 'big', + 'top', 'modeling', 'virtual' + ] + + def verify_jit_output(self, hf_pipeline, encoding, out): + tokenizer = hf_pipeline.tokenizer + mask_token_id = tokenizer.mask_token_id + mask = encoding["input_ids"].squeeze(0) == mask_token_id + + mask_index = torch.nonzero(mask, as_tuple=False).squeeze(0) + logits = out['logits'][0, mask_index] + answer = torch.argmax(logits) + prediction = tokenizer.decode(answer).strip() + + if prediction not in self.outputs: + text = self.inputs + if tokenizer.mask_token != "[MASK]": + text = text.replace("[MASK]", tokenizer.mask_token) + pipeline_output = hf_pipeline(text) + + if prediction not in [o["token_str"] for o in pipeline_output]: + logging.error(f"Unexpected inference result: {prediction}") + return False, "Unexpected inference result" + + logging.warning( + f"pipeline output differs from expected: {pipeline_output}") + + return True, None + + def encode_inputs(self, tokenizer): + text = self.inputs.replace("[MASK]", tokenizer.mask_token) + return tokenizer.encode_plus(text, return_tensors='pt') + + def get_extra_arguments(self, hf_pipeline, model_id: str, + temp_dir: str) -> dict: + return {"maskToken": hf_pipeline.tokenizer.mask_token} diff --git a/serving/docker/convert/huggingface_converter.py b/serving/docker/convert/huggingface_converter.py new file mode 100644 index 0000000000..167b4ad810 --- /dev/null +++ b/serving/docker/convert/huggingface_converter.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python +# +# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. 
+import logging +import os.path +import shutil +import sys +from argparse import Namespace + +import onnx +import safetensors_convert +import torch +from huggingface_hub import hf_hub_download, HfApi +from transformers import pipeline, AutoTokenizer, AutoConfig + +from metadata import HuggingfaceMetadata +from shasum import sha1_sum +from zip_utils import zip_dir + + +class PipelineHolder(object): + + def __init__(self, tokenizer, model): + self.tokenizer = tokenizer + self.model = model + + +class ModelHolder(object): + + def __init__(self, config): + self.config = config + + +class HuggingfaceConverter: + + def __init__(self): + self.device = torch.device( + "cuda:0" if torch.cuda.is_available() else "cpu") + self.task = None + self.application = None + self.translator = None + self.inputs = None + self.outputs = None + self.api = HfApi() + + def save_model(self, model_info, args: Namespace, temp_dir: str): + if args.output_format == "OnnxRuntime": + return self.save_onnx_model(model_info, args, temp_dir) + elif args.output_format == "Rust": + return self.save_rust_model(model_info, args, temp_dir) + else: + return self.save_pytorch_model(model_info, args, temp_dir) + + def save_onnx_model(self, model_info, args: Namespace, temp_dir: str): + model_id = model_info.modelId + + if not os.path.exists(temp_dir): + os.makedirs(temp_dir) + + logging.info(f"Saving onnxruntime model: {model_id} ...") + + from optimum.commands.optimum_cli import main + + sys.argv = [ + "model_zoo_importer.py", "export", "onnx", "-m", model_id, temp_dir + ] + main() + + model = onnx.load_model(os.path.join(temp_dir, "model.onnx"), + load_external_data=False) + inputs = repr(model.graph.input) + include_types = "token_type_id" in inputs + + tokenizer = AutoTokenizer.from_pretrained(model_id) + config = AutoConfig.from_pretrained(model_id) + hf_pipeline = PipelineHolder(tokenizer, ModelHolder(config)) + size = self.save_to_model_zoo(model_info, args.output_dir, + "OnnxRuntime", temp_dir, hf_pipeline, + include_types) + + return True, None, size + + def save_rust_model(self, model_info, args: Namespace, temp_dir: str): + model_id = model_info.modelId + + config = AutoConfig.from_pretrained(model_id) + if hasattr(config, "model_type"): + if config.model_type == "bert": + include_types = True + elif config.model_type == "distilbert": + include_types = False + else: + return False, f"Unsupported model_type: {config.model_type}", -1 + else: + return False, f"Unknown model_type: {model_id}", -1 + + logging.info(f"Saving rust model: {model_id} ...") + + if not os.path.exists(temp_dir): + os.makedirs(temp_dir) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + hf_pipeline = PipelineHolder(tokenizer, ModelHolder(config)) + try: + # Save tokenizer.json to temp dir + self.save_tokenizer(hf_pipeline, temp_dir) + except Exception as e: + logging.warning(f"Failed to save tokenizer: {model_id}.") + logging.warning(e, exc_info=True) + return False, "Failed to save tokenizer", -1 + + # Save config.json + config_file = hf_hub_download(repo_id=model_id, filename="config.json") + shutil.copyfile(config_file, os.path.join(temp_dir, "config.json")) + + target = os.path.join(temp_dir, "model.safetensors") + model = self.api.model_info(model_id, files_metadata=True) + has_sf_file = False + has_pt_file = False + for sibling in model.siblings: + if sibling.rfilename == "model.safetensors": + has_sf_file = True + elif sibling.rfilename == "pytorch_model.bin": + has_pt_file = True + + if has_sf_file: + file = 
hf_hub_download(repo_id=model_id, + filename="model.safetensors") + shutil.copyfile(file, target) + elif has_pt_file: + file = hf_hub_download(repo_id=model_id, + filename="pytorch_model.bin") + safetensors_convert.convert_file(file, target) + else: + return False, f"No model file found for: {model_id}", -1 + + size = self.save_to_model_zoo(model_info, args.output_dir, "Rust", + temp_dir, hf_pipeline, include_types) + + return True, None, size + + def save_pytorch_model(self, model_info, args: Namespace, temp_dir: str): + model_id = model_info.modelId + if not os.path.exists(temp_dir): + os.makedirs(temp_dir) + + try: + hf_pipeline = self.load_model(model_id) + except Exception as e: + logging.warning(f"Failed to load model: {model_id}.") + logging.warning(e, exc_info=True) + return False, "Failed to load model", -1 + + try: + # Save tokenizer.json to temp dir + self.save_tokenizer(hf_pipeline, temp_dir) + except Exception as e: + logging.warning(f"Failed to save tokenizer: {model_id}.") + logging.warning(e, exc_info=True) + return False, "Failed to save tokenizer", -1 + + # Save config.json just for reference + config = hf_hub_download(repo_id=model_id, filename="config.json") + shutil.copyfile(config, os.path.join(temp_dir, "config.json")) + + # Save jit traced .pt file to temp dir + include_types = False + model_file = self.jit_trace_model(hf_pipeline, model_id, temp_dir, + include_types) + if not model_file: + return False, "Failed to trace model", -1 + + result, reason = self.verify_jit_model(hf_pipeline, model_file, + include_types, args.cpu_only) + if not result: + include_types = True + model_file = self.jit_trace_model(hf_pipeline, model_id, temp_dir, + include_types) + if not model_file: + return False, reason, -1 + + result, reason = self.verify_jit_model(hf_pipeline, model_file, + include_types, + args.cpu_only) + if not result: + return False, reason, -1 + + size = self.save_to_model_zoo(model_info, args.output_dir, "PyTorch", + temp_dir, hf_pipeline, include_types) + + return True, None, size + + @staticmethod + def save_tokenizer(hf_pipeline, temp_dir: str): + hf_pipeline.tokenizer.save_pretrained(temp_dir) + if not os.path.exists(os.path.join(temp_dir, "tokenizer.json")): + raise ValueError("no fast tokenizer found.") + + # only keep tokenizer.json file + for path in os.listdir(temp_dir): + if path != "tokenizer.json" and path != "tokenizer_config.json": + os.remove(os.path.join(temp_dir, path)) + + def jit_trace_model(self, hf_pipeline, model_id: str, temp_dir: str, + include_types: bool): + logging.info( + f"Tracing model: {model_id} include_token_types={include_types} ..." 
+ ) + encoding = self.encode_inputs(hf_pipeline.tokenizer) + input_ids = encoding["input_ids"] + attention_mask = encoding["attention_mask"] + token_type_ids = encoding.get("token_type_ids") + if include_types and token_type_ids is None: + return None + + # noinspection PyBroadException + try: + if include_types: + script_module = torch.jit.trace( + hf_pipeline.model, + (input_ids, attention_mask, token_type_ids), + strict=False) + else: + script_module = torch.jit.trace(hf_pipeline.model, + (input_ids, attention_mask), + strict=False) + + model_name = model_id.split("/")[-1] + logging.info(f"Saving torchscript model: {model_name}.pt ...") + model_file = os.path.join(temp_dir, f"{model_name}.pt") + script_module.save(model_file) + except Exception as e: + logging.warning(f"Failed to trace model: {model_id}.") + logging.warning(e, exc_info=True) + return None + + return model_file + + def save_to_model_zoo(self, model_info, output_dir: str, engine: str, + temp_dir: str, hf_pipeline, include_types: bool): + model_id = model_info.modelId + model_name = model_id.split("/")[-1] + group_id = f"ai/djl/huggingface/{engine.lower()}" + repo_dir = f"{output_dir}/model/{self.application}/{group_id}/{model_id}" + model_dir = f"{repo_dir}/0.0.1" + if not os.path.exists(model_dir): + os.makedirs(model_dir) + + # Save serving.properties + serving_file = os.path.join(temp_dir, "serving.properties") + arguments = self.get_extra_arguments(hf_pipeline, model_id, temp_dir) + if include_types: + arguments["includeTokenTypes"] = "true" + arguments["translatorFactory"] = self.translator + + with open(serving_file, 'w') as f: + f.write(f"engine={engine}\n" + f"option.modelName={model_name}\n") + if engine == "PyTorch": + f.write(f"option.mapLocation=true\n") + + for k, v in arguments.items(): + f.write(f"{k}={v}\n") + + # Save model as .zip file + logging.info(f"Saving DJL model as zip: {model_name}.zip ...") + zip_file = os.path.join(model_dir, f"{model_name}.zip") + zip_dir(temp_dir, zip_file) + + # Save metadata.json + arguments["engine"] = engine + sha1 = sha1_sum(zip_file) + file_size = os.path.getsize(zip_file) + metadata = HuggingfaceMetadata(model_info, engine, self.application, + sha1, file_size, arguments) + metadata_file = os.path.join(repo_dir, "metadata.json") + metadata.save_metadata(metadata_file) + + return file_size + + def verify_jit_model(self, hf_pipeline, model_file: str, + include_types: bool, cpu_only: bool): + logging.info( + f"Verifying torchscript model(include_token_types={include_types}): {model_file} ..." 
+ ) + + tokenizer = hf_pipeline.tokenizer + encoding = self.encode_inputs(tokenizer) + + input_ids = encoding["input_ids"] + attention_mask = encoding["attention_mask"] + token_type_ids = encoding.get("token_type_ids") + if torch.cuda.is_available() and not cpu_only: + traced_model = torch.jit.load(model_file, map_location='cuda:0') + traced_model.to(self.device) + input_ids = input_ids.to(self.device) + attention_mask = attention_mask.to(self.device) + if token_type_ids is not None: + token_type_ids = token_type_ids.to(self.device) + else: + traced_model = torch.jit.load(model_file) + + traced_model.eval() + + try: + # test traced model + if include_types: + out = traced_model(input_ids, attention_mask, token_type_ids) + else: + out = traced_model(input_ids, attention_mask) + except RuntimeError as e: + logging.warning(e, exc_info=True) + return False, "Failed to run inference on jit model" + + return self.verify_jit_output(hf_pipeline, encoding, out) + + def get_extra_arguments(self, hf_pipeline, model_id: str, + temp_dir: str) -> dict: + return {} + + def verify_jit_output(self, hf_pipeline, encoding, out): + if not hasattr(out, "last_hidden_state"): + return False, f"Unexpected inference result: {out}" + + return True, None + + def load_model(self, model_id: str): + logging.info(f"Loading model: {model_id} ...") + kwargs = { + "tokenizer": model_id, + "device": -1 # always use CPU to trace the model + } + return pipeline(task=self.task, + model=model_id, + framework="pt", + **kwargs) + + def encode_inputs(self, tokenizer): + return tokenizer.encode_plus(self.inputs, return_tensors='pt') diff --git a/serving/docker/convert/huggingface_importer.py b/serving/docker/convert/huggingface_importer.py new file mode 100644 index 0000000000..de7599a5a1 --- /dev/null +++ b/serving/docker/convert/huggingface_importer.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python +# +# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License.
+import logging +import os.path +import shutil +import sys + +from arg_parser import converter_args +from fill_mask_converter import FillMaskConverter +from huggingface_models import HuggingfaceModels +from question_answering_converter import QuestionAnsweringConverter +from sentence_similarity_converter import SentenceSimilarityConverter +from text_classification_converter import TextClassificationConverter +from token_classification_converter import TokenClassificationConverter + +SUPPORTED_TASK = { + "fill-mask": FillMaskConverter(), + "question-answering": QuestionAnsweringConverter(), + "sentence-similarity": SentenceSimilarityConverter(), + "text-classification": TextClassificationConverter(), + "token-classification": TokenClassificationConverter(), +} + + +def main(): + logging.basicConfig(stream=sys.stdout, + format="%(message)s", + level=logging.INFO) + args = converter_args() + + huggingface_models = HuggingfaceModels(args.output_dir) + temp_dir = f"{args.output_dir}/tmp" + + models = huggingface_models.list_models(args) + if not models: + logging.warning(f"model not found: {args}") + + for model in models: + task = model["task"] + model_info = model["model_info"] + converter = SUPPORTED_TASK[task] + + try: + result, reason, size = converter.save_model( + model_info, args, temp_dir) + if not result: + logging.error(f"{model_info.modelId}: {reason}") + except Exception as e: + logging.warning(f"Failed to convert model: {model_info.modelId}.") + logging.warning(e, exc_info=True) + result = False + reason = "Failed to convert model" + size = -1 + + huggingface_models.update_progress(model_info, converter.application, + result, reason, size, args.cpu_only) + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + + logging.info("finished.") + + +if __name__ == "__main__": + main() diff --git a/serving/docker/convert/huggingface_models.py b/serving/docker/convert/huggingface_models.py new file mode 100644 index 0000000000..cba19d477e --- /dev/null +++ b/serving/docker/convert/huggingface_models.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python +# +# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. 
+import json +import logging +import os +from argparse import Namespace +from typing import List + +from huggingface_hub import HfApi +from huggingface_hub import hf_hub_download +from huggingface_hub.hf_api import ModelInfo + +ARCHITECTURES_2_TASK = { + "ForQuestionAnswering": "question-answering", + "ForTokenClassification": "token-classification", + "ForSequenceClassification": "text-classification", + "ForMultipleChoice": "text-classification", + "ForMaskedLM": "fill-mask", +} +LANGUAGES = HfApi().get_model_tags()["language"] + + +def get_lang_tags(model_info): + tags = {} + for tag in model_info.tags: + if tag in LANGUAGES: + tags[tag] = "true" + + if not tags: + tags["en"] = "true" + + return tags + + +class HuggingfaceModels: + + def __init__(self, output_dir: str): + self.output_dir = output_dir + self.processed_models = {} + + output_path = os.path.join(output_dir, "models.json") + if os.path.exists(output_path): + with open(output_path, "r") as f: + self.processed_models = json.load(f) + + self.temp_dir = f"{self.output_dir}/tmp" + + def list_models(self, args: Namespace) -> List[dict]: + import_all = os.environ.get("HF_IMPORT_ALL") + + api = HfApi() + if args.model_name: + all_models = api.list_models(search=args.model_name, + sort="downloads", + direction=-1, + limit=args.limit) + import_all = True + else: + all_models = api.list_models(filter=args.category, + sort="downloads", + direction=-1, + limit=args.limit) + models = [ + model for model in all_models + if 'pytorch' in model.tags or 'safetensors' in model.tags + ] + if not models: + if args.model_name: + logging.warning(f"no model found: {args.model_name}.") + else: + logging.warning(f"no model matches category: {args.category}.") + + return [] + + ret = [] + for model_info in models: + model_id = model_info.modelId + + # flair model is not supported yet + if "flair" in model_info.tags: + logging.info(f"Skip flair model: {model_id}.") + continue + + languages = get_lang_tags(model_info) + if "en" not in languages and not import_all: + logging.warning(f"Skip non-English model: {model_id}.") + continue + + existing_model = self.processed_models.get(model_id) + if existing_model: + existing_model["downloads"] = model_info.downloads + if not args.retry_failed or existing_model[ + "result"] == "success": + logging.info(f"Skip converted model: {model_id}.") + continue + + if model_info.downloads < 50 and not import_all: + logging.info( + f"Skip model {model_info.modelId}, downloads {model_info.downloads} < 50" + ) + continue + + try: + config = hf_hub_download(repo_id=model_id, + filename="config.json") + except EnvironmentError: + logging.info(f"Skip {model_id}, no config.json found.") + continue + + with open(config) as f: + config = json.load(f) + + task, architecture = self.to_supported_task(config) + if not task: + if "sentence-similarity" in model_info.tags: + task = "sentence-similarity" + + if not task: + logging.info( + f"Unsupported model architecture: {architecture} for {model_id}." + ) + continue + + if args.category and args.category != task: + logging.info( + f"Skip {model_id}, expect task: {args.category}, detected {task}." 
+ ) + continue + + model = { + "model_info": model_info, + "config": config, + "task": task, + } + ret.append(model) + + return ret + + def update_progress(self, model_info: ModelInfo, application: str, + result: bool, reason: str, size: int, cpu_only: bool): + status = { + "result": "success" if result else "failed", + "application": application, + "sha1": model_info.sha, + "size": size, + "downloads": model_info.downloads, + } + if reason: + status["reason"] = reason + if cpu_only: + status["cpu_only"] = True + + self.processed_models[model_info.modelId] = status + + dict_file = os.path.join(self.output_dir, "models.json") + with open(dict_file, 'w') as f: + json.dump(self.processed_models, + f, + sort_keys=True, + indent=2, + ensure_ascii=False) + + @staticmethod + def to_supported_task(config: dict): + architectures = config.get("architectures") + if not architectures: + return None, "No architectures found" + + architecture = architectures[0] + for arch in ARCHITECTURES_2_TASK: + if architecture.endswith(arch): + return ARCHITECTURES_2_TASK[arch], architecture + + return None, architecture diff --git a/serving/docker/convert/metadata.py b/serving/docker/convert/metadata.py new file mode 100644 index 0000000000..7e835eb94d --- /dev/null +++ b/serving/docker/convert/metadata.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# +# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. 
+import json +from huggingface_models import get_lang_tags + + +class HuggingfaceMetadata: + + def __init__(self, model_info, engine: str, application: str, sha1: str, + file_size: int, arguments: dict): + self.model_info = model_info + self.group_id = f"ai.djl.huggingface.{engine.lower()}" + self.artifact_id = model_info.modelId + self.model_name = model_info.modelId.split("/")[-1] + self.application = application + self.sha1 = sha1 + self.file_size = file_size + self.arguments = arguments + self.options = {} + if engine == "PyTorch": + self.options["mapLocation"] = True + + def save_metadata(self, metadata_file: str): + properties = get_lang_tags(self.model_info) + + metadata = { + "metadataVersion": + "0.2", + "resourceType": + "model", + "application": + self.application, + "groupId": + self.group_id, + "artifactId": + self.artifact_id, + "name": + self.model_name, + "description": + f"Huggingface transformers model: {self.model_name}", + "website": + "http://www.djl.ai/extensions/tokenizers", + "licenses": { + "license": { + "name": "The Apache License, Version 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0" + } + }, + "artifacts": [{ + "version": "0.0.1", + "snapshot": False, + "name": self.model_name, + "properties": properties, + "arguments": self.arguments, + "options": self.options, + "files": { + "model": { + "uri": f"0.0.1/{self.model_name}.zip", + "name": "", + "sha1Hash": self.sha1, + "size": self.file_size + } + } + }] + } + with open(metadata_file, 'w') as f: + json.dump(metadata, + f, + sort_keys=False, + indent=2, + ensure_ascii=False) diff --git a/serving/docker/convert/question_answering_converter.py b/serving/docker/convert/question_answering_converter.py new file mode 100644 index 0000000000..935c211860 --- /dev/null +++ b/serving/docker/convert/question_answering_converter.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +# +# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. +import logging + +import torch + +from huggingface_converter import HuggingfaceConverter + + +class QuestionAnsweringConverter(HuggingfaceConverter): + + def __init__(self): + super().__init__() + self.task = "question-answering" + self.application = "nlp/question_answer" + self.translator = "ai.djl.huggingface.translator.QuestionAnsweringTranslatorFactory" + self.inputs = { + "question": + "When did BBC Japan start broadcasting?", + "context": + "BBC Japan was a general entertainment Channel. Which operated between December 2004 and April 2006." + " It ceased operations after its Japanese distributor folded." 
+ } + self.outputs = "december 2004" + + def verify_jit_output(self, hf_pipeline, encoding, out): + tokenizer = hf_pipeline.tokenizer + input_ids = encoding["input_ids"] + + start_ = out["start_logits"] + end_ = out["end_logits"] + start_[0, 0] = -10000 + end_[0, 0] = -10000 + answer_start = torch.argmax(start_) + answer_end = torch.argmax(end_) + 1 + + out_ids = input_ids[0].tolist()[answer_start:answer_end] + tokens = tokenizer.convert_ids_to_tokens(out_ids) + prediction = tokenizer.convert_tokens_to_string(tokens).strip() + + if prediction.lower() != self.outputs: + pipeline_output = hf_pipeline(self.inputs) + if pipeline_output["answer"].strip() != prediction: + logging.error(f"Unexpected inference result: {prediction}") + return False, "Unexpected inference result" + + logging.warning( + f"pipeline output differs from expected: {pipeline_output}") + + return True, None + + def encode_inputs(self, tokenizer): + text = self.inputs["question"] + text_pair = self.inputs["context"] + return tokenizer.encode_plus(text, + text_pair=text_pair, + return_tensors='pt') diff --git a/serving/docker/convert/requirements.txt b/serving/docker/convert/requirements.txt new file mode 100644 index 0000000000..28adc04880 --- /dev/null +++ b/serving/docker/convert/requirements.txt @@ -0,0 +1,6 @@ +huggingface_hub +transformers +torch +protobuf==3.20.2 +optimum[exporters,onnxruntime] +safetensors \ No newline at end of file diff --git a/serving/docker/convert/safetensors_convert.py b/serving/docker/convert/safetensors_convert.py new file mode 100644 index 0000000000..fd0a3cb04f --- /dev/null +++ b/serving/docker/convert/safetensors_convert.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python +# +# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. +import os +from collections import defaultdict +from typing import List, Dict + +import torch +from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file + + +def _remove_duplicate_names( + state_dict: Dict[str, torch.Tensor], + *, + preferred_names: List[str] = None, + discard_names: List[str] = None, +) -> Dict[str, List[str]]: + if preferred_names is None: + preferred_names = [] + preferred_names = set(preferred_names) + if discard_names is None: + discard_names = [] + discard_names = set(discard_names) + + shareds = _find_shared_tensors(state_dict) + to_remove = defaultdict(list) + for shared in shareds: + complete_names = set( + [name for name in shared if _is_complete(state_dict[name])]) + if not complete_names: + if len(shared) == 1: + # Force contiguous + name = list(shared)[0] + state_dict[name] = state_dict[name].clone() + complete_names = {name} + else: + raise RuntimeError( + f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. 
Or open an issue." + ) + + keep_name = sorted(list(complete_names))[0] + + preferred = complete_names.difference(discard_names) + if preferred: + keep_name = sorted(list(preferred))[0] + + if preferred_names: + preferred = preferred_names.intersection(complete_names) + if preferred: + keep_name = sorted(list(preferred))[0] + for name in sorted(shared): + if name != keep_name: + to_remove[keep_name].append(name) + return to_remove + + +def convert_file(pt_filename: str, sf_filename: str): + loaded = torch.load(pt_filename, map_location="cpu") + if "state_dict" in loaded: + loaded = loaded["state_dict"] + to_removes = _remove_duplicate_names(loaded) + + metadata = {"format": "pt"} + for kept_name, to_remove_group in to_removes.items(): + for to_remove in to_remove_group: + if to_remove not in metadata: + metadata[to_remove] = kept_name + del loaded[to_remove] + # Force tensors to be contiguous + loaded = {k: v.contiguous() for k, v in loaded.items()} + + dir_name = os.path.dirname(sf_filename) + os.makedirs(dir_name, exist_ok=True) + save_file(loaded, sf_filename, metadata=metadata) + check_file_size(sf_filename, pt_filename) + reloaded = load_file(sf_filename) + for k in loaded: + pt_tensor = loaded[k] + sf_tensor = reloaded[k] + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") + + +def check_file_size(sf_filename: str, pt_filename: str): + sf_size = os.stat(sf_filename).st_size + pt_size = os.stat(pt_filename).st_size + + if (sf_size - pt_size) / pt_size > 0.01: + raise RuntimeError(f"""The file size different is more than 1%: + - {sf_filename}: {sf_size} + - {pt_filename}: {pt_size} + """) diff --git a/serving/docker/convert/sentence_similarity_converter.py b/serving/docker/convert/sentence_similarity_converter.py new file mode 100644 index 0000000000..7d28db57ea --- /dev/null +++ b/serving/docker/convert/sentence_similarity_converter.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python +# +# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. 
+import json +import logging +import os +import shutil + +import requests +import torch +from transformers import AutoTokenizer, AutoModel, AutoConfig + +from huggingface_converter import HuggingfaceConverter, PipelineHolder +from huggingface_hub import hf_hub_download + + +class SentenceSimilarityConverter(HuggingfaceConverter): + + def __init__(self): + super().__init__() + self.task = "sentence-similarity" + self.application = "nlp/text_embedding" + self.translator = "ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory" + self.inputs = "This is an example sentence" + self.outputs = 0 + + def load_model(self, model_id: str): + logging.info(f"Loading model: {model_id} ...") + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = AutoModel.from_pretrained(model_id) + + return PipelineHolder(tokenizer, model) + + def verify_jit_output(self, hf_pipeline, encoding, out): + last_hidden_state = out["last_hidden_state"].to("cpu") + + pipeline_output = hf_pipeline.model(**encoding) + expected = pipeline_output.last_hidden_state + if not torch.allclose( + expected, last_hidden_state, atol=1e-05, rtol=1e-03): + return False, "Unexpected inference result" + + return True, None + + def get_extra_arguments(self, hf_pipeline, model_id: str, + temp_dir: str) -> dict: + args = {"padding": "true"} + for config_name in [ + 'sentence_bert_config.json', 'sentence_roberta_config.json', + 'sentence_distilbert_config.json', + 'sentence_camembert_config.json', + 'sentence_albert_config.json', + 'sentence_xlm-roberta_config.json', + 'sentence_xlnet_config.json' + ]: + try: + file = hf_hub_download(repo_id=model_id, filename=config_name) + with open(file) as f: + config = json.load(f) + if config.get("max_seq_length"): + args["maxLength"] = config.get("max_seq_length") + if config.get("do_lower_case"): + args["doLowerCase"] = config.get("do_lower_case") + + break + except requests.exceptions.HTTPError: + pass + + if not "maxLength" in args: + if hasattr(hf_pipeline.model, "config"): + config = hf_pipeline.model.config + else: + config = AutoConfig.from_pretrained(model_id) + tokenizer = hf_pipeline.tokenizer + if hasattr(config, "max_position_embeddings") and hasattr( + tokenizer, "model_max_length"): + max_seq_length = min(config.max_position_embeddings, + tokenizer.model_max_length) + args["maxLength"] = str(max_seq_length) + + pooling_path = None + dense_path = None + layer_norm_path = None + normalize = False + try: + file = hf_hub_download(repo_id=model_id, filename="modules.json") + with open(file, "r") as f: + modules = json.load(f) + + for module in modules: + module_type = module.get("type") + if module_type == "sentence_transformers.models.Pooling": + pooling_path = module["path"] + elif module_type == "sentence_transformers.models.Dense": + dense_path = module["path"] + elif module_type == "sentence_transformers.models.LayerNorm": + layer_norm_path = module["path"] + elif module_type == "sentence_transformers.models.Normalize": + normalize = "true" + elif module_type != "sentence_transformers.models.Transformer": + logging.warning(f"Unexpected module: {module_type}.") + except requests.exceptions.HTTPError: + logging.warning(f"{model_id}: modules.json not found.") + + if pooling_path: + try: + file = hf_hub_download(repo_id=model_id, + filename=f"{pooling_path}/config.json") + if os.path.exists(file): + with open(file, "r") as f: + pooling = json.load(f) + if pooling.get("pooling_mode_cls_token"): + args["pooling"] = "cls" + elif pooling.get("pooling_mode_max_tokens"): + args["pooling"] 
= "max" + elif pooling.get("pooling_mode_mean_sqrt_len_tokens"): + args["pooling"] = "mean_sqrt_len" + elif pooling.get("pooling_mode_weightedmean_tokens"): + args["pooling"] = "weightedmean" + elif pooling.get("pooling_mode_lasttoken"): + args["pooling"] = "lasttoken" + except requests.exceptions.HTTPError: + logging.warning( + f"{model_id}: {pooling_path}/config.json not found.") + + if dense_path: + try: + file = hf_hub_download(repo_id=model_id, + filename=f"{dense_path}/config.json") + with open(file, "r") as f: + dense = json.load(f) + activation = dense.get("activation_function") + if activation == "torch.nn.modules.activation.Tanh": + args["denseActivation"] = "Tanh" + elif activation != "torch.nn.modules.linear.Identity": + logging.warning( + f"Unexpected activation function: {activation}.") + self.save_module_weight(model_id, temp_dir, dense_path, + "linear") + args["dense"] = "linear.safetensors" + except requests.exceptions.HTTPError: + logging.debug(f"{model_id}: {dense_path} not found.") + + if layer_norm_path: + try: + self.save_module_weight(model_id, temp_dir, layer_norm_path, + "norm") + args["layerNorm"] = "norm.safetensors" + except requests.exceptions.HTTPError: + logging.warning(f"{model_id}: {layer_norm_path} not found.") + + if not normalize: + args["normalize"] = "false" + + return args + + @staticmethod + def save_module_weight(model_id: str, temp_dir: str, layer: str, + name: str): + file = hf_hub_download(repo_id=model_id, + filename=f"{layer}/model.safetensors") + shutil.copyfile(file, os.path.join(temp_dir, f"{name}.safetensors")) diff --git a/serving/docker/convert/shasum.py b/serving/docker/convert/shasum.py new file mode 100644 index 0000000000..9441985305 --- /dev/null +++ b/serving/docker/convert/shasum.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# +# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. +import hashlib + + +def sha1_sum(file_name: str): + with open(file_name, 'rb') as f: + content = f.read() + return hashlib.sha1(content).hexdigest() diff --git a/serving/docker/convert/text_classification_converter.py b/serving/docker/convert/text_classification_converter.py new file mode 100644 index 0000000000..8b1dba4efa --- /dev/null +++ b/serving/docker/convert/text_classification_converter.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python +# +# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. 
+import logging +import math + +import torch + +from huggingface_converter import HuggingfaceConverter + + +class TextClassificationConverter(HuggingfaceConverter): + + def __init__(self): + super().__init__() + self.task = "text-classification" + self.application = "nlp/text_classification" + self.translator = "ai.djl.huggingface.translator.TextClassificationTranslatorFactory" + self.inputs = "DJL is the best." + self.outputs = None + + def verify_jit_output(self, hf_pipeline, encoding, out): + config = hf_pipeline.model.config + logits = out['logits'][0] + + if config.problem_type == "multi_label_classification" or config.num_labels == 1: + logits = torch.sigmoid(logits) + elif config.problem_type == "single_label_classification" or config.num_labels > 1: + logits = torch.softmax(logits, dim=0) + elif hasattr(config, "function_to_apply"): + logging.error( + f"Customized function not supported: {config.function_to_apply}" + ) + return False, "Customized function not supported" + + index = logits.argmax().item() + label = config.id2label[index] + score = logits[index] + pipeline_output = hf_pipeline(self.inputs) + + for item in pipeline_output: + if item["label"] == label: + if math.isclose(item["score"], score, abs_tol=1e-3): + return True, None + break + + return False, f"Unexpected inference result" diff --git a/serving/docker/convert/token_classification_converter.py b/serving/docker/convert/token_classification_converter.py new file mode 100644 index 0000000000..fead8d247b --- /dev/null +++ b/serving/docker/convert/token_classification_converter.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python +# +# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. 
+import logging + +import torch + +from huggingface_converter import HuggingfaceConverter + + +class TokenClassificationConverter(HuggingfaceConverter): + + def __init__(self): + super().__init__() + self.task = "token-classification" + self.application = "nlp/token_classification" + self.translator = "ai.djl.huggingface.translator.TokenClassificationTranslatorFactory" + self.inputs = "My name is Wolfgang and I live in Berlin" + self.outputs = ["wolfgang", "PER", "PROPN"] + + def verify_jit_output(self, hf_pipeline, encoding, out): + config = hf_pipeline.model.config + tokenizer = hf_pipeline.tokenizer + + logits = out["logits"][0].detach() + input_ids = encoding["input_ids"][0].tolist() + offset_mapping = encoding.encodings[0].offsets + special_token_masks = encoding.encodings[0].special_tokens_mask + probabilities = torch.softmax(logits, dim=1) + entities = [] + + for idx, scores in enumerate(probabilities): + if special_token_masks[idx]: + continue + + entity_idx = scores.argmax().item() + entity = config.id2label[entity_idx] + + if entity != "O": + item = { + "entity": entity, + "score": scores[entity_idx], + "index": idx, + "word": tokenizer.convert_ids_to_tokens(input_ids[idx]), + "start": offset_mapping[idx][0], + "end": offset_mapping[idx][1], + } + entities.append(item) + + if self.outputs[0] in item["word"].lower() and ( + self.outputs[1] in entity + or self.outputs[2] in entity): + return True, None + + if len(entities) == 0: + return False, "TokenClassification returns with empty result" + + pipeline_output = hf_pipeline(self.inputs) + for e in pipeline_output: + if e["word"] == entities[0]["word"]: + if e["entity"] == entities[0]["entity"]: + logging.warning( + f"pipeline output differs from expected: {pipeline_output}" + ) + return True, None + else: + break + + logging.error(f"Unexpected inference result: {entities[0]}") + + return False, "Unexpected inference result" diff --git a/serving/docker/convert/zip_utils.py b/serving/docker/convert/zip_utils.py new file mode 100644 index 0000000000..6445f1272c --- /dev/null +++ b/serving/docker/convert/zip_utils.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# +# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. 
+import os +import zipfile + + +def add_to_zip(path: str, handle: zipfile.ZipFile): + for root, dirs, files in os.walk(path): + for file in files: + file_path = os.path.join(root, file) + entry_name = os.path.relpath(file_path, path) + handle.write(file_path, entry_name) + + +def zip_dir(src_dir: str, output_file: str): + with zipfile.ZipFile(output_file, 'w', zipfile.ZIP_DEFLATED) as f: + add_to_zip(src_dir, f) diff --git a/serving/docker/lmi.Dockerfile b/serving/docker/lmi.Dockerfile index ef475bc35e..83e8b02beb 100644 --- a/serving/docker/lmi.Dockerfile +++ b/serving/docker/lmi.Dockerfile @@ -69,9 +69,11 @@ COPY scripts scripts/ RUN mkdir -p /opt/djl/conf && \ mkdir -p /opt/djl/deps && \ mkdir -p /opt/djl/partition && \ + mkdir -p /opt/djl/convert && \ mkdir -p /opt/ml/model COPY config.properties /opt/djl/conf/config.properties COPY partition /opt/djl/partition +COPY convert /opt/djl/convert COPY distribution[s]/ ./ RUN mv *.deb djl-serving_all.deb || true @@ -92,7 +94,7 @@ RUN pip3 install torch==${torch_version} torchvision==${torch_vision_version} -- transformers==${transformers_version} hf-transfer zstandard datasets==${datasets_version} \ mpi4py sentencepiece tiktoken blobfile einops accelerate==${accelerate_version} bitsandbytes==${bitsandbytes_version} \ optimum==${optimum_version} auto-gptq==${auto_gptq_version} pandas pyarrow jinja2 \ - opencv-contrib-python-headless safetensors scipy && \ + opencv-contrib-python-headless safetensors scipy onnx onnxruntime sentence_transformers && \ pip3 cache purge RUN pip3 install ${flash_attn_2_wheel} ${lmi_dist_wheel} ${vllm_wheel} pydantic==${pydantic_version} && \ diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java index 8d7727dd3c..782d6272be 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java @@ -91,8 +91,7 @@ static boolean isTrtLlmRollingBatch(Properties properties) { static boolean needConvert(ModelInfo info) { Properties properties = info.getProperties(); - return isTrtLlmRollingBatch(info.getProperties()) - || properties.containsKey("trtllm_python_backend"); + return isTrtLlmRollingBatch(properties) || properties.containsKey("trtllm_python_backend"); } static void convertTrtLLM(ModelInfo info) throws IOException { @@ -136,6 +135,70 @@ static void convertTrtLLM(ModelInfo info) throws IOException { } } + static void convertOnnx(ModelInfo info) throws IOException { + String prefix = info.prop.getProperty("option.modelName", "model"); + if (Files.isRegularFile(info.modelDir.resolve(prefix + ".onnx")) + || Files.isRegularFile(info.modelDir.resolve("model.onnx"))) { + return; + } + + Path repo; + String modelId = null; + if (info.downloadDir != null) { + repo = info.downloadDir; + } else { + repo = info.modelDir; + modelId = info.prop.getProperty("option.model_id"); + if (modelId != null && Files.isDirectory(Paths.get(modelId))) { + repo = Paths.get(modelId); + } + } + + if (modelId == null) { + modelId = repo.toString(); + } + info.modelDir = exportOnnx(modelId, repo); + } + + private static Path exportOnnx(String modelId, Path repoDir) throws IOException { + logger.info("Converting model to onnx artifacts"); + String[] cmd = { + "python", + "/opt/djl/convert/huggingface_importer.py", + "--output-dir", + repoDir.toAbsolutePath().toString(), + "--output-format", + "OnnxRuntime", + "--model-name", + modelId + }; + boolean success = false; + try { + Process exec = new 
ProcessBuilder(cmd).redirectErrorStream(true).start(); + try (BufferedReader reader = + new BufferedReader( + new InputStreamReader(exec.getInputStream(), StandardCharsets.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) { + logger.debug("convert: {}", line); + } + } + int exitCode = exec.waitFor(); + if (0 != exitCode) { + throw new EngineException("Model conversion process failed!"); + } + success = true; + logger.info("Onnx artifacts built successfully"); + return repoDir; + } catch (InterruptedException e) { + throw new IOException("Failed to build onnx artifacts", e); + } finally { + if (!success) { + Utils.deleteQuietly(repoDir); + } + } + } + /** * Returns the Huggingface config.json file URI. * diff --git a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java index 76d7928882..0174303ecc 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java @@ -526,7 +526,16 @@ public void initialize() throws IOException, ModelException { metric = new Metric("ConvertTrtllm", duration, Unit.MICROSECONDS, dimension); MODEL_METRIC.info("{}", metric); eventManager.onModelConverted(this, "trtllm"); + } else if ("OnnxRuntime".equals(getEngineName())) { + eventManager.onModelConverting(this, "onnx"); + begin = System.nanoTime(); + LmiUtils.convertOnnx(this); + duration = (System.nanoTime() - begin) / 1000; + metric = new Metric("ConvertOnnx", duration, Unit.MICROSECONDS, dimension); + MODEL_METRIC.info("{}", metric); + eventManager.onModelConverted(this, "onnx"); } + // override prop keys are not write to serving.properties, // we have to explicitly set in Criteria if (options == null) { diff --git a/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java b/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java index d41f775873..0745b753d0 100644 --- a/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java +++ b/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java @@ -185,6 +185,8 @@ public void testInitModel() throws IOException, ModelException { Path onnx = modelDir.resolve("test_model.onnx"); Files.createFile(onnx); model = new ModelInfo<>("build/models/test_model"); + model.prop = new Properties(); + model.prop.put("option.modelName", "test_model"); model.initialize(); assertEquals(model.getEngineName(), "OnnxRuntime");