Commit

Agent-based routing.
jondurbin committed Aug 26, 2023
1 parent 509b178 commit 81381aa
Showing 2 changed files with 112 additions and 22 deletions.
132 changes: 111 additions & 21 deletions airoboros/lmoe/api.py
@@ -25,11 +25,12 @@
    StoppingCriteria,
    StoppingCriteriaList,
)
-from typing import List, Dict
+from typing import List, Dict, Any, Optional

warnings.filterwarnings("ignore")
MODEL_LOCK = asyncio.Lock()
MODELS = {}
DESCRIPTIONS = {}
ROLE_MAP = {
"user": "USER",
"assistant": "ASSISTANT",
@@ -43,12 +43,19 @@
# "\nRemember,"
# "\nPlease note,"
]

# Hacky way to handle the same stop string being tokenized differently
# depending on the surrounding text...
USER_STOP_TOKENS = [
-    torch.tensor([3148, 1001, 29901], device="cuda:0"),
-    torch.tensor([11889, 29901], device="cuda:0"),
+    torch.tensor([3148, 1001, 29901], device="cuda"),
+    torch.tensor([11889, 29901], device="cuda"),
]
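# Illustrative note (exact token ids unverified): the Llama-2 tokenizer can
# split "USER:" into either three pieces or two depending on what precedes
# it, which is why both id sequences above are treated as stops.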
ROUTING_PROMPT_TEMPLATE = """A chat.
USER: As an AI assistant, choose the correct function from the list of available functions below, according to the user's request. Your response should be in JSON format.
Input: {instruction}
Available functions:
{functions}
ASSISTANT: """
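# The function expert is expected to answer with JSON along the lines of
# {"function": "coding"} (adapter name here is hypothetical); route_via_agent
# below parses the selected name out of the completion.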

app = fastapi.FastAPI()

@@ -94,6 +102,70 @@ async def list_models():
    }


def route_via_agent(model: Any, request: ChatRequest, stopping_criteria: Any) -> Optional[str]:
    """Route a request using the LLM with the adapter descriptions."""
    loaded_expert = getattr(model, "__expert__", None)
    if loaded_expert != "function":
        model.set_adapter("function")
        setattr(model, "__expert__", "function")

    # We'll just use the system prompt and last message for the instruction.
    instruction = " ".join(
        [
            request.messages[0]["content"].strip(),
            request.messages[-1]["content"],
        ]
    )
    functions = "\n".join(
        [
            f"{name}:\n description: {description}"
            for name, description in DESCRIPTIONS.items()
        ]
    )
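    # With hypothetical adapters loaded, `functions` renders roughly as:
    #   coding:
    #    description: Programming and code-generation requests.
    #   creative:
    #    description: Story writing and roleplay.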
    prompt = ROUTING_PROMPT_TEMPLATE.format(
        instruction=instruction, functions=functions
    )
    input_ids = MODELS["__tokenizer__"](prompt, return_tensors="pt")["input_ids"].to(
        "cuda"
    )
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=True, enable_mem_efficient=False
    ):
        outputs = model.generate(
            input_ids=input_ids,
            stopping_criteria=stopping_criteria,
            max_new_tokens=16,
            temperature=0.3,
            top_p=0.8,
            top_k=50,
            use_cache=False,
            early_stopping=True,
        )
    response = (
        MODELS["__tokenizer__"]
        .batch_decode(
            outputs.detach().cpu().numpy(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        .split("ASSISTANT:")[-1]
        .strip()
    )
    # Strip any trailing role marker, then extract the selected function:
    # first try the JSON '"function": "..."' form, then fall back to any
    # quoted adapter name appearing in the response.
    response = re.sub(r"[\s\n]*(USER|ASSISTANT):\s*$", "", response, flags=re.DOTALL)
    result = re.search(r'"function":\s*"([^"]+)"', response, re.I)
    if not result:
        result = re.search(
            '"(' + "|".join([re.escape(name) for name in DESCRIPTIONS]) + ')"',
            response,
            re.I,
        )
    if result:
        expert = result.group(1)
        logger.info(f"Agent-based routing selection: {expert}")
        return expert
    return None
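# Example flow (adapter names hypothetical): for a request whose last message
# is "Write a quicksort in Python", the routing completion should contain
# {"function": "coding"}, so the "coding" LoRA gets activated; unparseable
# output falls back to the "reasoning" default in complete_request below.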


def complete_request(request):
    """Sync method to complete a request, to make sure we aren't messing with the model/LoRAs concurrently."""
    if any(
@@ -158,16 +230,6 @@ def complete_request(request):
detail="Prompt length + max_tokens exceeds max model length.",
)

-    # Route the request to the appropriate expert (LoRA).
-    started_at = datetime.datetime.utcnow()
-    expert = MODELS[request.model]["router"].route(prompt)
-    model = MODELS[request.model]["model"]
-    loaded_expert = getattr(model, "__expert__", None)
-    if loaded_expert != expert:
-        model.set_adapter(expert)
-        setattr(model, "__expert__", expert)
-    routing_duration = (datetime.datetime.utcnow() - started_at).total_seconds()

    # Update our stopping criteria.
    stop_words = request.stop or DEFAULT_STOPS
    stopping_criteria = None
@@ -182,6 +244,25 @@
            [StoppingCriteriaSub(stops=stop_words_ids)]
        )

    # Route the request to the appropriate expert (LoRA).
    started_at = datetime.datetime.utcnow()
    model = MODELS[request.model]["model"]
    if "router" in MODELS[request.model]:
        expert = MODELS[request.model]["router"].route(prompt)
    else:
        expert = route_via_agent(model, request, stopping_criteria)
        if not expert or expert not in DESCRIPTIONS:
            logger.warning("Error performing expert selection, using default")
            expert = "reasoning"

    # Load the adapter.
    model = MODELS[request.model]["model"]
    loaded_expert = getattr(model, "__expert__", None)
    if loaded_expert != expert:
        model.set_adapter(expert)
        setattr(model, "__expert__", expert)
    routing_duration = (datetime.datetime.utcnow() - started_at).total_seconds()

    # Generate the response.
    started_at = datetime.datetime.utcnow()
    with torch.backends.cuda.sdp_kernel(
@@ -303,6 +384,12 @@ def main():
default="thenlper/gte-small",
help="model to use for embeddings in expert router",
)
    parser.add_argument(
        "-m",
        "--embedding-router",
        action="store_true",
        help="use the training data to route requests via similarity search, rather than agent routing",
    )
    parser.add_argument(
        "-s",
        "--router-max-samples",
@@ -358,24 +445,27 @@
            )
            .to_bettertransformer()
            .eval(),
-            os.path.abspath(os.path.join(lmoe, "adapters", "general")),
-            adapter_name="general",
+            os.path.abspath(os.path.join(lmoe, "adapters", "function")),
+            adapter_name="function",
        ),
"router": Router(
}
if args.embedding_router:
MODELS[base_name]["router"] = Router(
model_name_or_path=args.router_model,
input_paths=routing_paths,
max_samples=args.router_max_samples,
k=args.router_k,
),
}
)
    logger.info(
-        f"Loading adapters for {base_name} from {lmoe}: this too is slow..."
+        f"Loading adapters for {base_name} from {lmoe}: activating all thrusters..."
    )
    for path in tqdm(glob.glob(os.path.join(lmoe, "adapters", "*"))):
        name = os.path.basename(str(path))
-        if name == "general":
+        if name == "function":
            continue
        MODELS[base_name]["model"].load_adapter(str(path), name)
        with open(os.path.join(str(path), "description.txt")) as infile:
            DESCRIPTIONS[name] = infile.read().strip()
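    # Each adapter directory is assumed to also ship a short plain-text
    # description.txt (e.g. "Programming and code-generation requests" --
    # wording hypothetical), which becomes the function description shown
    # to the routing expert.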

    # Start the API server.
    uvicorn.run(
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@

setup(
    name="airoboros",
-    version="2.1.6",
+    version="2.1.7",
    description="Updated and improved implementation of the self-instruct system.",
    long_description=long_description,
    long_description_content_type="text/markdown",
