diff --git a/examples/custom/model.py b/examples/custom/model.py
index 771a5150e7..c87e6e87e7 100644
--- a/examples/custom/model.py
+++ b/examples/custom/model.py
@@ -3,8 +3,11 @@
                        get_model_tokenizer_with_flash_attn, register_model, register_template)
 
 register_template(
-    TemplateMeta('custom', ['System\n{{SYSTEM}}\n'],
-                 ['User\n{{QUERY}}\nAssistant\n'], ['\n']))
+    TemplateMeta(
+        template_type='custom',
+        prefix=['System\n{{SYSTEM}}\n'],
+        prompt=['User\n{{QUERY}}\nAssistant\n'],
+        chat_sep=['\n']))
 
 register_model(
     ModelMeta(
@@ -17,7 +20,7 @@
         ignore_file_pattern=['nemo']))
 
 if __name__ == '__main__':
-    infer_request = InferRequest(messages=[{'role': 'user', 'content': '你是谁'}])
+    infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
     request_config = RequestConfig(max_tokens=512, temperature=0)
     engine = PtEngine('AI-ModelScope/Nemotron-Mini-4B-Instruct')
     response = engine.infer([infer_request], request_config)
@@ -27,3 +30,4 @@
     response = engine.infer([infer_request], request_config)
     jinja_response = response[0].choices[0].message.content
     assert swift_response == jinja_response, (f'swift_response: {swift_response}\njinja_response: {jinja_response}')
+    print(f'response: {swift_response}')
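
For readers unfamiliar with the template fields, the keyword-argument form introduced above encodes the same layout as the old positional call: prefix is emitted once with the system message, prompt wraps each user turn and ends where the assistant's reply should begin, and chat_sep separates consecutive rounds. The snippet below is a minimal, standalone sketch of that string layout in plain Python; it does not use swift's template classes, and the names PREFIX, PROMPT, CHAT_SEP and render_single_turn are hypothetical helpers for illustration only.

# Hypothetical sketch of the string layout encoded by the 'custom' TemplateMeta above.
# Plain Python only; swift's real template engine is not used here.
PREFIX = 'System\n{system}\n'          # mirrors prefix=['System\n{{SYSTEM}}\n']
PROMPT = 'User\n{query}\nAssistant\n'  # mirrors prompt=['User\n{{QUERY}}\nAssistant\n']
CHAT_SEP = '\n'                        # mirrors chat_sep=['\n']; separates finished rounds


def render_single_turn(system: str, query: str) -> str:
    """Lay out one system + user exchange the way the registered template would."""
    return PREFIX.format(system=system) + PROMPT.format(query=query)


if __name__ == '__main__':
    print(render_single_turn('You are a helpful assistant.', 'who are you?'))
    # System
    # You are a helpful assistant.
    # User
    # who are you?
    # Assistant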