Skip to content

Commit

Permalink
Make controller simpler to import and use
Browse files Browse the repository at this point in the history
  • Loading branch information
iojw committed Jul 7, 2024
1 parent 82a1410 commit 11a88a6
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 41 deletions.
8 changes: 3 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,9 @@ os.environ["ANYSCALE_API_KEY"] = "esecret_XXXXXX"
# client = OpenAI()
client = Controller(
routers=["mf"],
routed_pair=ModelPair(
strong="gpt-4-1106-preview",
# Mixtral 8x7B model provided by Anyscale
weak="anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1",
),
strong_model="gpt-4-1106-preview",
# Mixtral 8x7B model provided by Anyscale
weak_model="anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1",
)
```
Above, we pick `gpt-4-1106-preview` as the strong model and `anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1` as the weak model, setting the API keys accordingly. You can route between different model pairs or providers by updating the model names as described in [Model Support](#model-support).
Expand Down
6 changes: 2 additions & 4 deletions examples/routing_to_local_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@ os.environ["OPENAI_API_KEY"] = "sk-XXXXXX"

client = Controller(
routers=["mf"],
routed_pair=ModelPair(
strong="gpt-4-1106-preview",
weak="ollama_chat/llama3",
),
strong_model="gpt-4-1106-preview",
weak_model="ollama_chat/llama3",
)
```

Expand Down
19 changes: 13 additions & 6 deletions routellm/controller.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from collections import defaultdict
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Optional

import pandas as pd
import tqdm
from litellm import acompletion, completion

from routellm.model_pair import ModelPair
from routellm.routers.routers import ROUTER_CLS

# Default config for routers augmented using golden label data from GPT-4.
Expand All @@ -32,17 +32,24 @@ class RoutingError(Exception):
pass


# Holds the names of the two models a Controller routes between.
# Controller.__init__ builds one from its strong_model / weak_model arguments
# and hands it to each router's route() call.
@dataclass
class ModelPair:
# Model name supplied as `strong_model` (e.g. "gpt-4-1106-preview").
strong: str
# Model name supplied as `weak_model` (e.g. "ollama_chat/llama3").
weak: str


class Controller:
def __init__(
self,
routers: list[str],
routed_pair: ModelPair,
progress_bar: bool = False,
strong_model: str,
weak_model: str,
config: Optional[dict[str, dict[str, Any]]] = None,
api_base: Optional[str] = None,
api_key: Optional[str] = None,
progress_bar: bool = False,
):
self.routed_pair = routed_pair
self.model_pair = ModelPair(strong=strong_model, weak=weak_model)
self.routers = {}
self.api_base = api_base
self.api_key = api_key
Expand Down Expand Up @@ -95,7 +102,7 @@ def _get_routed_model_for_completion(
# Look at the last turn for routing.
# Our current routers were only trained on first turn data, so more research is required here.
prompt = messages[-1]["content"]
routed_model = self.routers[router].route(prompt, threshold, self.routed_pair)
routed_model = self.routers[router].route(prompt, threshold, self.model_pair)

self.model_counts[router][routed_model] += 1

Expand All @@ -117,7 +124,7 @@ def batch_calculate_win_rate(
def route(self, prompt: str, router: str, threshold: float):
self._validate_router_threshold(router, threshold)

return self.routers[router].route(prompt, threshold, self.routed_pair)
return self.routers[router].route(prompt, threshold, self.model_pair)

# Matches OpenAI's Chat Completions interface, but also supports optional router and threshold args
# If model name is present, attempt to parse router and threshold using it, otherwise, use the router and threshold args
Expand Down
27 changes: 18 additions & 9 deletions routellm/evals/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from routellm.controller import Controller
from routellm.evals.benchmarks import GSM8K, MMLU, MTBench
from routellm.evals.mmlu.domains import ALL_MMLU_DOMAINS
from routellm.model_pair import ModelPair
from routellm.routers.routers import ROUTER_CLS

os.environ["TOKENIZERS_PARALLELISM"] = "false"
Expand Down Expand Up @@ -195,24 +194,24 @@ def pretty_print_results(threshold, accuracy, model_counts, total):
print(args)

pandarallel.initialize(progress_bar=True, nb_workers=args.parallel)
routed_pair = ModelPair(strong=args.strong_model, weak=args.weak_model)
controller = Controller(
routers=args.routers,
config=yaml.safe_load(open(args.config, "r")) if args.config else None,
routed_pair=routed_pair,
strong_model=args.strong_model,
weak_model=args.weak_model,
progress_bar=True,
)

if args.benchmark == "mmlu":
print("Running eval for full MMLU.")
mmlu_domains = ALL_MMLU_DOMAINS
benchmark = MMLU(mmlu_domains, routed_pair, args.overwrite_cache)
benchmark = MMLU(mmlu_domains, controller.model_pair, args.overwrite_cache)
elif args.benchmark == "mt-bench":
print("Running eval for MT Bench.")
benchmark = MTBench(routed_pair, args.overwrite_cache)
benchmark = MTBench(controller.model_pair, args.overwrite_cache)
elif args.benchmark == "gsm8k":
print("Running eval for GSM8k.")
benchmark = GSM8K(routed_pair, args.overwrite_cache)
benchmark = GSM8K(controller.model_pair, args.overwrite_cache)
else:
raise ValueError(f"Invalid benchmark {args.benchmark}")

Expand All @@ -230,7 +229,9 @@ def pretty_print_results(threshold, accuracy, model_counts, total):
router_results.append(
{
"threshold": threshold,
"strong_percentage": model_counts[routed_pair.strong]
"strong_percentage": model_counts[
controller.model_pair.strong
]
/ total
* 100,
"accuracy": accuracy,
Expand All @@ -254,10 +255,18 @@ def pretty_print_results(threshold, accuracy, model_counts, total):
result = {
"method": str(router),
"threshold": threshold,
"strong_percentage": model_counts[routed_pair.strong] / total * 100,
"strong_percentage": model_counts[controller.model_pair.strong]
/ total
* 100,
"accuracy": accuracy,
}
router_results.append(result)
all_results = pd.concat([all_results, pd.DataFrame(router_results)])

generate_results(all_results, benchmark, args.benchmark, routed_pair, args.output)
generate_results(
all_results,
benchmark,
args.benchmark,
controller.model_pair.strong,
args.output,
)
2 changes: 1 addition & 1 deletion routellm/evals/gsm8k/generate_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
from openai import OpenAI

from routellm.model_pair import ModelPair
from routellm.controller import ModelPair

"""
The core code is based heavily on the original SGLang implementation.
Expand Down
2 changes: 1 addition & 1 deletion routellm/evals/mmlu/generate_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import tqdm
from openai import OpenAI

from routellm.controller import ModelPair
from routellm.evals.mmlu.domains import ALL_MMLU_DOMAINS
from routellm.model_pair import ModelPair

ROUTED_PAIR = ModelPair(
strong="gpt-4-1106-preview", weak="mistralai/Mixtral-8x7B-Instruct-v0.1"
Expand Down
7 changes: 0 additions & 7 deletions routellm/model_pair.py

This file was deleted.

5 changes: 2 additions & 3 deletions routellm/openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from pydantic import BaseModel, Field

from routellm.controller import Controller, RoutingError
from routellm.model_pair import ModelPair
from routellm.routers.routers import ROUTER_CLS

os.environ["TOKENIZERS_PARALLELISM"] = "false"
Expand All @@ -34,11 +33,11 @@
async def lifespan(app):
global CONTROLLER

routed_pair = ModelPair(strong=args.strong_model, weak=args.weak_model)
CONTROLLER = Controller(
routers=args.routers,
config=yaml.safe_load(open(args.config, "r")) if args.config else None,
routed_pair=routed_pair,
strong_model=args.strong_model,
weak_model=args.weak_model,
api_base=args.base_url,
api_key=args.api_key,
progress_bar=True,
Expand Down
7 changes: 2 additions & 5 deletions routellm/tests/test_openai_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import argparse

from routellm.controller import Controller
from routellm.model_pair import ModelPair
from routellm.routers.routers import ROUTER_CLS

system_content = (
Expand Down Expand Up @@ -32,10 +31,8 @@

client = Controller(
routers=[args.router],
routed_pair=ModelPair(
strong="gpt-4-1106-preview",
weak="anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1",
),
strong_model="gpt-4-1106-preview",
weak_model="anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1",
)

chat_completion = client.chat.completions.create(
Expand Down

0 comments on commit 11a88a6

Please sign in to comment.