Add generation logging support
Generations can be logged in the console along with sampling parameters
if the user enables it in config.

Metrics are always logged at the end of each prompt. In addition,
the model endpoint tells the user whether they're being logged
for transparency purposes.

Signed-off-by: kingbri <bdashore3@proton.me>
bdashore3 committed Dec 13, 2023
1 parent b364de1 commit 083df7d
Showing 5 changed files with 91 additions and 10 deletions.
2 changes: 2 additions & 0 deletions OAI/types/model.py
@@ -1,6 +1,7 @@
from pydantic import BaseModel, Field
from time import time
from typing import List, Optional
from gen_logging import LogConfig

class ModelCardParameters(BaseModel):
    max_seq_len: Optional[int] = 4096
@@ -14,6 +15,7 @@ class ModelCard(BaseModel):
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time()))
    owned_by: str = "tabbyAPI"
    logging: Optional[LogConfig] = None
    parameters: Optional[ModelCardParameters] = None

class ModelList(BaseModel):
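Because the model card now carries the active LogConfig, an API client can check up front whether its prompts will be logged. Below is a minimal client-side sketch, assuming the current-model route is served at /v1/model on the default port from config_sample.yml and authenticated with an x-api-key header (route path, port, and header name are assumptions, not part of this diff):

import requests

# Assumed endpoint and auth header; adjust to the actual deployment
response = requests.get(
    "http://localhost:5000/v1/model",
    headers={"x-api-key": "your-api-key"},  # hypothetical key
)
model_card = response.json()

# The "logging" field mirrors gen_logging.LogConfig: {"prompt": bool, "generation_params": bool}
log_prefs = model_card.get("logging") or {}
if log_prefs.get("prompt"):
    print("Note: this server logs prompts and responses to its console")
if log_prefs.get("generation_params"):
    print("Note: this server logs sampling parameters")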
8 changes: 8 additions & 0 deletions config_sample.yml
@@ -13,6 +13,14 @@ network:
  # The port to host on (default: 5000)
  port: 5000

# Options for logging
logging:
  # Enable prompt logging (default: False)
  prompt: False

  # Enable generation parameter logging (default: False)
  generation_params: False

# Options for model overrides and loading
model:
  # Overrides the directory to look for models (default: models)
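For reference, the logging block above deserializes into the LogConfig model added in gen_logging.py below. A minimal sketch of that round trip, assuming the sample file sits in the working directory and PyYAML plus pydantic are installed:

import yaml
from gen_logging import LogConfig

# Read the sample config and pull out the optional logging section
with open("config_sample.yml", "r", encoding="utf8") as config_file:
    config = yaml.safe_load(config_file) or {}

log_options = config.get("logging") or {}

# Missing keys fall back to the LogConfig defaults
log_config = LogConfig.parse_obj(log_options)
print(log_config)  # both fields are False for the sample file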
47 changes: 47 additions & 0 deletions gen_logging.py
@@ -0,0 +1,47 @@
from typing import Dict
from pydantic import BaseModel

# Logging preference config
class LogConfig(BaseModel):
    prompt: bool = False
    generation_params: bool = False

# Global reference to logging preferences
config = LogConfig()

# Wrapper to set the logging config for generations
def update_from_dict(options_dict: Dict[str, bool]):
    global config

    # Force bools on the dict by replacing None values with False
    for key, value in options_dict.items():
        if value is None:
            options_dict[key] = False

    config = LogConfig.parse_obj(options_dict)

def broadcast_status():
    enabled = []
    if config.prompt:
        enabled.append("prompts")

    if config.generation_params:
        enabled.append("generation params")

    if len(enabled) > 0:
        print("Generation logging is enabled for: " + ", ".join(enabled))
    else:
        print("Generation logging is disabled")

# Logs generation parameters to console
def log_generation_params(**kwargs):
    if config.generation_params:
        print(f"Generation options: {kwargs}\n")

def log_prompt(prompt: str):
    if config.prompt:
        print(f"Prompt: {prompt if prompt else 'Empty'}\n")

def log_response(response: str):
    if config.prompt:
        print(f"Response: {response if response else 'Empty'}\n")
11 changes: 10 additions & 1 deletion main.py
@@ -1,6 +1,7 @@
import uvicorn
import yaml
import pathlib
import gen_logging
from asyncio import CancelledError
from auth import check_admin_key, check_api_key, load_auth_keys
from fastapi import FastAPI, Request, HTTPException, Depends
@@ -81,7 +82,8 @@ async def get_current_model():
            rope_alpha = model_container.config.scale_alpha_value,
            max_seq_len = model_container.config.max_seq_len,
            prompt_template = unwrap(model_container.prompt_template, "auto")
        )
        ),
        logging = gen_logging.config
    )

    if model_container.draft_config:
@@ -370,6 +372,13 @@ async def generator():
        )
        config = {}

    # Override the generation log options if given
    log_config = unwrap(config.get("logging"), {})
    if log_config:
        gen_logging.update_from_dict(log_config)

    gen_logging.broadcast_status()

    # If an initial model name is specified, create a container and load the model
    model_config = unwrap(config.get("model"), {})
    if "model_name" in model_config:
33 changes: 24 additions & 9 deletions model.py
@@ -14,6 +14,7 @@
)
from typing import List, Optional, Union
from utils import coalesce, unwrap
from gen_logging import log_generation_params, log_prompt, log_response

# Bytes to reserve on first device when loading with auto split
auto_split_reserve_bytes = 96 * 1024**2
@@ -351,21 +352,32 @@ def generate_gen(self, prompt: str, **kwargs):
        stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
        ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)


        # Ban the EOS token if specified. If not, append to stop conditions as well.
        if ban_eos_token:
            gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
        else:
            stop_conditions.append(self.tokenizer.eos_token_id)

        # Override sampler settings for temp = 0
        if gen_settings.temperature == 0:
            gen_settings.temperature = 1.0
            gen_settings.top_k = 1
            gen_settings.top_p = 0
            gen_settings.typical = 0

        # Stop conditions
        # Log generation options to console
        log_generation_params(
            **vars(gen_settings),
            token_healing = token_healing,
            max_tokens = max_tokens,
            stop_conditions = stop_conditions
        )

        # Log prompt to console
        log_prompt(prompt)

        # Ban the EOS token if specified. If not, append to stop conditions as well.
        # Set this below logging to avoid polluting the stop strings array
        if ban_eos_token:
            gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
        else:
            stop_conditions.append(self.tokenizer.eos_token_id)

        # Stop conditions
        self.generator.set_stop_conditions(stop_conditions)

        # Tokenized context
@@ -430,9 +442,12 @@ def generate_gen(self, prompt: str, **kwargs):

            if eos or generated_tokens == max_tokens: break

        # Print response
        log_response(full_response)

        elapsed_time = last_chunk_time - start_time

        initial_response = f"Response: {generated_tokens} tokens generated in {round(elapsed_time, 2)} seconds"
        initial_response = f"Metrics: {generated_tokens} tokens generated in {round(elapsed_time, 2)} seconds"
        itemization = []
        extra_parts = []

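The log_generation_params call above leans on vars() to flatten the sampler settings object into plain keyword arguments, so every attribute lands in the single kwargs dict that gets printed. A minimal sketch of that pattern with a stand-in settings class (SamplerSettings below is hypothetical, not the ExLlamaV2 settings object model.py actually uses):

import gen_logging
from gen_logging import log_generation_params

class SamplerSettings:
    """Stand-in for the real sampler settings object."""
    def __init__(self):
        self.temperature = 0.8
        self.top_k = 40
        self.top_p = 0.9

# Pretend the user enabled generation parameter logging in config.yml
gen_logging.update_from_dict({"prompt": False, "generation_params": True})

settings = SamplerSettings()

# vars(settings) expands to temperature/top_k/top_p keyword arguments,
# and the extra values are merged into the same kwargs dict before printing
log_generation_params(**vars(settings), max_tokens=150, stop_conditions=["</s>"])
# Generation options: {'temperature': 0.8, 'top_k': 40, 'top_p': 0.9,
#                      'max_tokens': 150, 'stop_conditions': ['</s>']}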
