From 11a88a6cfc58cd564abaf98c13f84e32628f6bb0 Mon Sep 17 00:00:00 2001
From: Isaac Ong
Date: Sun, 7 Jul 2024 12:18:40 -0700
Subject: [PATCH] Make controller simpler to import and use

---
 README.md                                  |  8 +++----
 examples/routing_to_local_models.md        |  6 ++---
 routellm/controller.py                     | 19 ++++++++++-----
 routellm/evals/evaluate.py                 | 27 ++++++++++++++--------
 routellm/evals/gsm8k/generate_responses.py |  2 +-
 routellm/evals/mmlu/generate_responses.py  |  2 +-
 routellm/model_pair.py                     |  7 ------
 routellm/openai_server.py                  |  5 ++--
 routellm/tests/test_openai_client.py       |  7 ++----
 9 files changed, 42 insertions(+), 41 deletions(-)
 delete mode 100644 routellm/model_pair.py

diff --git a/README.md b/README.md
index 65e6ba5..e50226e 100644
--- a/README.md
+++ b/README.md
@@ -43,11 +43,9 @@ os.environ["ANYSCALE_API_KEY"] = "esecret_XXXXXX"
 # client = OpenAI()
 client = Controller(
     routers=["mf"],
-    routed_pair=ModelPair(
-        strong="gpt-4-1106-preview",
-        # Mixtral 8x7B model provided by Anyscale
-        weak="anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1",
-    ),
+    strong_model="gpt-4-1106-preview",
+    # Mixtral 8x7B model provided by Anyscale
+    weak_model="anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1",
 )
 ```
 Above, we pick `gpt-4-1106-preview` as the strong model and `anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1` as the weak model, setting the API keys accordingly. You can route between different model pairs or providers by updating the model names as described in [Model Support](#model-support).

diff --git a/examples/routing_to_local_models.md b/examples/routing_to_local_models.md
index d5fc334..b958f85 100644
--- a/examples/routing_to_local_models.md
+++ b/examples/routing_to_local_models.md
@@ -20,10 +20,8 @@ os.environ["OPENAI_API_KEY"] = "sk-XXXXXX"
 
 client = Controller(
     routers=["mf"],
-    routed_pair=ModelPair(
-        strong="gpt-4-1106-preview",
-        weak="ollama_chat/llama3",
-    ),
+    strong_model="gpt-4-1106-preview",
+    weak_model="ollama_chat/llama3",
 )
 ```
 
diff --git a/routellm/controller.py b/routellm/controller.py
index fa92e84..d1bc75a 100644
--- a/routellm/controller.py
+++ b/routellm/controller.py
@@ -1,4 +1,5 @@
 from collections import defaultdict
+from dataclasses import dataclass
 from types import SimpleNamespace
 from typing import Any, Optional
 
@@ -6,7 +7,6 @@
 import tqdm
 from litellm import acompletion, completion
 
-from routellm.model_pair import ModelPair
 from routellm.routers.routers import ROUTER_CLS
 
 # Default config for routers augmented using golden label data from GPT-4.
@@ -32,17 +32,24 @@ class RoutingError(Exception):
     pass
 
 
+@dataclass
+class ModelPair:
+    strong: str
+    weak: str
+
+
 class Controller:
     def __init__(
         self,
         routers: list[str],
-        routed_pair: ModelPair,
-        progress_bar: bool = False,
+        strong_model: str,
+        weak_model: str,
         config: Optional[dict[str, dict[str, Any]]] = None,
         api_base: Optional[str] = None,
         api_key: Optional[str] = None,
+        progress_bar: bool = False,
     ):
-        self.routed_pair = routed_pair
+        self.model_pair = ModelPair(strong=strong_model, weak=weak_model)
         self.routers = {}
         self.api_base = api_base
         self.api_key = api_key
@@ -95,7 +102,7 @@ def _get_routed_model_for_completion(
         # Look at the last turn for routing.
         # Our current routers were only trained on first turn data, so more research is required here.
         prompt = messages[-1]["content"]
-        routed_model = self.routers[router].route(prompt, threshold, self.routed_pair)
+        routed_model = self.routers[router].route(prompt, threshold, self.model_pair)
 
         self.model_counts[router][routed_model] += 1
 
@@ -117,7 +124,7 @@ def batch_calculate_win_rate(
 
     def route(self, prompt: str, router: str, threshold: float):
         self._validate_router_threshold(router, threshold)
-        return self.routers[router].route(prompt, threshold, self.routed_pair)
+        return self.routers[router].route(prompt, threshold, self.model_pair)
 
     # Matches OpenAI's Chat Completions interface, but also supports optional router and threshold args
     # If model name is present, attempt to parse router and threshold using it, otherwise, use the router and threshold args
diff --git a/routellm/evals/evaluate.py b/routellm/evals/evaluate.py
index 23fcadd..d0eb2f7 100644
--- a/routellm/evals/evaluate.py
+++ b/routellm/evals/evaluate.py
@@ -11,7 +11,6 @@
 from routellm.controller import Controller
 from routellm.evals.benchmarks import GSM8K, MMLU, MTBench
 from routellm.evals.mmlu.domains import ALL_MMLU_DOMAINS
-from routellm.model_pair import ModelPair
 from routellm.routers.routers import ROUTER_CLS
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -195,24 +194,24 @@ def pretty_print_results(threshold, accuracy, model_counts, total):
     print(args)
     pandarallel.initialize(progress_bar=True, nb_workers=args.parallel)
 
-    routed_pair = ModelPair(strong=args.strong_model, weak=args.weak_model)
     controller = Controller(
         routers=args.routers,
         config=yaml.safe_load(open(args.config, "r")) if args.config else None,
-        routed_pair=routed_pair,
+        strong_model=args.strong_model,
+        weak_model=args.weak_model,
         progress_bar=True,
     )
 
     if args.benchmark == "mmlu":
         print("Running eval for full MMLU.")
         mmlu_domains = ALL_MMLU_DOMAINS
-        benchmark = MMLU(mmlu_domains, routed_pair, args.overwrite_cache)
+        benchmark = MMLU(mmlu_domains, controller.model_pair, args.overwrite_cache)
     elif args.benchmark == "mt-bench":
         print("Running eval for MT Bench.")
-        benchmark = MTBench(routed_pair, args.overwrite_cache)
+        benchmark = MTBench(controller.model_pair, args.overwrite_cache)
     elif args.benchmark == "gsm8k":
         print("Running eval for GSM8k.")
-        benchmark = GSM8K(routed_pair, args.overwrite_cache)
+        benchmark = GSM8K(controller.model_pair, args.overwrite_cache)
     else:
         raise ValueError(f"Invalid benchmark {args.benchmark}")
@@ -230,7 +229,9 @@ def pretty_print_results(threshold, accuracy, model_counts, total):
             router_results.append(
                 {
                     "threshold": threshold,
-                    "strong_percentage": model_counts[routed_pair.strong]
+                    "strong_percentage": model_counts[
+                        controller.model_pair.strong
+                    ]
                     / total
                     * 100,
                     "accuracy": accuracy,
@@ -254,10 +255,18 @@ def pretty_print_results(threshold, accuracy, model_counts, total):
             result = {
                 "method": str(router),
                 "threshold": threshold,
-                "strong_percentage": model_counts[routed_pair.strong] / total * 100,
+                "strong_percentage": model_counts[controller.model_pair.strong]
+                / total
+                * 100,
                 "accuracy": accuracy,
             }
             router_results.append(result)
         all_results = pd.concat([all_results, pd.DataFrame(router_results)])
 
-    generate_results(all_results, benchmark, args.benchmark, routed_pair, args.output)
+    generate_results(
+        all_results,
+        benchmark,
+        args.benchmark,
+        controller.model_pair.strong,
+        args.output,
+    )
diff --git a/routellm/evals/gsm8k/generate_responses.py b/routellm/evals/gsm8k/generate_responses.py
index 7ca7583..ee2ed79 100644
--- a/routellm/evals/gsm8k/generate_responses.py
+++ b/routellm/evals/gsm8k/generate_responses.py
@@ -9,7 +9,7 @@
 import pandas as pd
 from openai import OpenAI
 
-from routellm.model_pair import ModelPair
+from routellm.controller import ModelPair
 
 """
 The core code is based heavily on the original SGLang implementation.
diff --git a/routellm/evals/mmlu/generate_responses.py b/routellm/evals/mmlu/generate_responses.py
index 707a9a2..77bf336 100644
--- a/routellm/evals/mmlu/generate_responses.py
+++ b/routellm/evals/mmlu/generate_responses.py
@@ -10,8 +10,8 @@
 import tqdm
 from openai import OpenAI
 
+from routellm.controller import ModelPair
 from routellm.evals.mmlu.domains import ALL_MMLU_DOMAINS
-from routellm.model_pair import ModelPair
 
 ROUTED_PAIR = ModelPair(
     strong="gpt-4-1106-preview", weak="mistralai/Mixtral-8x7B-Instruct-v0.1"
diff --git a/routellm/model_pair.py b/routellm/model_pair.py
deleted file mode 100644
index 2ff5ffc..0000000
--- a/routellm/model_pair.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from dataclasses import dataclass
-
-
-@dataclass
-class ModelPair:
-    strong: str
-    weak: str
diff --git a/routellm/openai_server.py b/routellm/openai_server.py
index f2761f0..7df4192 100644
--- a/routellm/openai_server.py
+++ b/routellm/openai_server.py
@@ -20,7 +20,6 @@
 from pydantic import BaseModel, Field
 
 from routellm.controller import Controller, RoutingError
-from routellm.model_pair import ModelPair
 from routellm.routers.routers import ROUTER_CLS
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -34,11 +33,11 @@
 async def lifespan(app):
     global CONTROLLER
 
-    routed_pair = ModelPair(strong=args.strong_model, weak=args.weak_model)
     CONTROLLER = Controller(
         routers=args.routers,
         config=yaml.safe_load(open(args.config, "r")) if args.config else None,
-        routed_pair=routed_pair,
+        strong_model=args.strong_model,
+        weak_model=args.weak_model,
         api_base=args.base_url,
         api_key=args.api_key,
         progress_bar=True,
diff --git a/routellm/tests/test_openai_client.py b/routellm/tests/test_openai_client.py
index 18fc19a..0f4009b 100644
--- a/routellm/tests/test_openai_client.py
+++ b/routellm/tests/test_openai_client.py
@@ -1,7 +1,6 @@
 import argparse
 
 from routellm.controller import Controller
-from routellm.model_pair import ModelPair
 from routellm.routers.routers import ROUTER_CLS
 
 system_content = (
@@ -32,10 +31,8 @@
 
 client = Controller(
     routers=[args.router],
-    routed_pair=ModelPair(
-        strong="gpt-4-1106-preview",
-        weak="anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1",
-    ),
+    strong_model="gpt-4-1106-preview",
+    weak_model="anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1",
 )
 
 chat_completion = client.chat.completions.create(
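
For reference, a minimal sketch of how the controller is used after this patch. The constructor keywords, `client.model_pair`, and the `route()` signature come from the diff above; the router name `"mf"`, the threshold value, the prompt, and the placeholder API keys are illustrative assumptions, not part of the commit.

```python
import os

# After this patch, ModelPair lives in routellm/controller.py,
# so both names import from the same module.
from routellm.controller import Controller, ModelPair

os.environ["OPENAI_API_KEY"] = "sk-XXXXXX"         # placeholder key
os.environ["ANYSCALE_API_KEY"] = "esecret_XXXXXX"  # placeholder key

# The strong_model/weak_model keywords replace the old
# routed_pair=ModelPair(...) argument.
client = Controller(
    routers=["mf"],
    strong_model="gpt-4-1106-preview",
    weak_model="anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1",
)

# The controller still builds a ModelPair internally, now exposed as
# client.model_pair (this is what the eval harness reads).
print(client.model_pair.strong, client.model_pair.weak)

# Route a single prompt; returns the name of the model the router picks.
# The threshold here is an arbitrary example value.
routed_model = client.route(
    "What is the capital of France?", router="mf", threshold=0.5
)
print(routed_model)
```

Per the comment in controller.py, `client.chat.completions.create(...)` continues to match OpenAI's Chat Completions interface, so existing OpenAI-style call sites (as in the README and test above) are unchanged by this refactor.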