From 010ff8c11db4e2d3b4812ec8bdf6b6ef94442687 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 29 Oct 2024 17:56:26 +0100 Subject: [PATCH] feat: use custom http_client MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit adds the ability to pass a custom HTTP client to the MT-Bench evaluator. This is handy when using custom certificates when interacting with the judge model serving endpoint. Signed-off-by: Sébastien Han --- .spellcheck-en-custom.txt | 1 + CHANGELOG.md | 4 ++++ requirements.txt | 1 + src/instructlab/eval/mt_bench.py | 15 +++++++++++++++ src/instructlab/eval/mt_bench_answers.py | 3 ++- src/instructlab/eval/mt_bench_common.py | 11 +++++++++-- src/instructlab/eval/mt_bench_judgment.py | 3 ++- tests/test_branch_gen_answers.py | 8 +++++++- 8 files changed, 41 insertions(+), 5 deletions(-) diff --git a/.spellcheck-en-custom.txt b/.spellcheck-en-custom.txt index 769b05b..33a582f 100644 --- a/.spellcheck-en-custom.txt +++ b/.spellcheck-en-custom.txt @@ -10,6 +10,7 @@ dr eval gpt hoc +http instructlab jsonl justfile diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ba63cc..a897297 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1 +1,5 @@ +## 0.4 + +* Added ability to specify a custom http client to MT-Bench + ## v0.2 diff --git a/requirements.txt b/requirements.txt index 9be7cbd..a3e6e7d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ accelerate pandas pandas-stubs lm-eval>=0.4.4 +httpx diff --git a/src/instructlab/eval/mt_bench.py b/src/instructlab/eval/mt_bench.py index 2d9a12a..3b24aa8 100644 --- a/src/instructlab/eval/mt_bench.py +++ b/src/instructlab/eval/mt_bench.py @@ -10,6 +10,9 @@ import multiprocessing import os +# Third Party +import httpx + # First Party from instructlab.eval import ( mt_bench_answers, @@ -110,6 +113,7 @@ def gen_answers( api_key: str | None = None, max_workers: int | str | None = None, serving_gpus: int | None = None, + http_client: httpx.Client | None = None, ) -> None: """ Asks questions to model @@ -119,6 +123,7 @@ def gen_answers( api_key API token for authenticating with model server max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor. serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor. + http_client Custom http client to use for requests """ logger.debug(locals()) mt_bench_answers.generate_answers( @@ -127,6 +132,7 @@ def gen_answers( api_key=api_key, output_dir=self.output_dir, max_workers=self._get_effective_max_workers(max_workers, serving_gpus), + http_client=http_client, ) def judge_answers( @@ -135,6 +141,7 @@ def judge_answers( api_key: str | None = None, max_workers: int | str | None = None, serving_gpus: int | None = None, + http_client: httpx.Client | None = None, ) -> tuple: """ Runs MT-Bench judgment @@ -144,6 +151,7 @@ def judge_answers( api_key API token for authenticating with model server max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor. serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor. + http_client Custom http client to use for requests Returns: overall_score MT-Bench score for the overall model evaluation @@ -160,6 +168,7 @@ def judge_answers( max_workers=self._get_effective_max_workers(max_workers, serving_gpus), output_dir=self.output_dir, merge_system_user_message=self.merge_system_user_message, + http_client=http_client, ) @@ -202,6 +211,7 @@ def gen_answers( api_key: str | None = None, max_workers: int | str | None = None, serving_gpus: int | None = None, + http_client: httpx.Client | None = None, ) -> None: """ Asks questions to model @@ -211,6 +221,7 @@ def gen_answers( api_key API token for authenticating with model server max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor. serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor. + http_client Custom http client to use for requests """ logger.debug(locals()) mt_bench_branch_generator.generate( @@ -228,6 +239,7 @@ def gen_answers( data_dir=self.output_dir, max_workers=self._get_effective_max_workers(max_workers, serving_gpus), bench_name="mt_bench_branch", + http_client=http_client, ) def judge_answers( @@ -236,6 +248,7 @@ def judge_answers( api_key: str | None = None, max_workers: int | str | None = None, serving_gpus: int | None = None, + http_client: httpx.Client | None = None, ) -> tuple: """ Runs MT-Bench-Branch judgment. Judgments can be compared across runs with consistent question_id -> qna file name. @@ -245,6 +258,7 @@ def judge_answers( api_key API token for authenticating with model server max_workers Max parallel workers to run the evaluation with (int or "auto"). None indicates to use value specified in constructor. serving_gpus Number of gpus allocated for serving. Used to tune with max_workers=auto. None indicates to use value specified in constructor. + http_client Custom http client to use for requests Returns: overall_score Overall score from the evaluation @@ -263,5 +277,6 @@ def judge_answers( data_dir=self.output_dir, bench_name="mt_bench_branch", merge_system_user_message=self.merge_system_user_message, + http_client=http_client, ) return overall_score, qa_pairs, error_rate diff --git a/src/instructlab/eval/mt_bench_answers.py b/src/instructlab/eval/mt_bench_answers.py index ac6b98b..f4337b4 100644 --- a/src/instructlab/eval/mt_bench_answers.py +++ b/src/instructlab/eval/mt_bench_answers.py @@ -108,11 +108,12 @@ def generate_answers( max_tokens=1024, max_workers=1, bench_name="mt_bench", + http_client=None, ): """Generate model answers to be judged""" logger.debug(locals()) - openai_client = get_openai_client(model_api_base, api_key) + openai_client = get_openai_client(model_api_base, api_key, http_client) if data_dir is None: data_dir = os.path.join(os.path.dirname(__file__), "data") diff --git a/src/instructlab/eval/mt_bench_common.py b/src/instructlab/eval/mt_bench_common.py index f45bf5f..f3f8068 100644 --- a/src/instructlab/eval/mt_bench_common.py +++ b/src/instructlab/eval/mt_bench_common.py @@ -13,6 +13,7 @@ import time # Third Party +import httpx import openai # First Party @@ -365,8 +366,14 @@ def get_model_list(answer_file): return [os.path.splitext(os.path.basename(answer_file))[0]] -def get_openai_client(model_api_base, api_key): +def get_openai_client( + model_api_base, + api_key, + http_client: httpx.Client | None = None, +): if api_key is None: api_key = "NO_API_KEY" - openai_client = openai.OpenAI(base_url=model_api_base, api_key=api_key) + openai_client = openai.OpenAI( + base_url=model_api_base, api_key=api_key, http_client=http_client + ) return openai_client diff --git a/src/instructlab/eval/mt_bench_judgment.py b/src/instructlab/eval/mt_bench_judgment.py index 53ba315..f853a09 100644 --- a/src/instructlab/eval/mt_bench_judgment.py +++ b/src/instructlab/eval/mt_bench_judgment.py @@ -286,11 +286,12 @@ def generate_judgment( max_workers=1, first_n=None, merge_system_user_message=False, + http_client=None, ): """Generate judgment with scores and qa_pairs for a model""" logger.debug(locals()) - openai_client = get_openai_client(model_api_base, api_key) + openai_client = get_openai_client(model_api_base, api_key, http_client) first_n_env = os.environ.get("INSTRUCTLAB_EVAL_FIRST_N_QUESTIONS") if first_n_env is not None and first_n is None: diff --git a/tests/test_branch_gen_answers.py b/tests/test_branch_gen_answers.py index 04f85ac..e5e0d36 100755 --- a/tests/test_branch_gen_answers.py +++ b/tests/test_branch_gen_answers.py @@ -1,3 +1,6 @@ +# Third Party +import httpx + # First Party from instructlab.eval.mt_bench import MTBenchBranchEvaluator @@ -7,4 +10,7 @@ "../taxonomy", "main", ) -mt_bench_branch.gen_answers("http://localhost:8000/v1") +mt_bench_branch.gen_answers( + "http://localhost:8000/v1", + http_client=httpx.Client(verify=False), +)