log model calls when model providers return bad request errors (#815)
Co-authored-by: aisi-inspect <166920645+aisi-inspect@users.noreply.github.com>
jjallaire-aisi and aisi-inspect authored Nov 6, 2024
1 parent 91c5a3a commit de06eb6
Showing 5 changed files with 77 additions and 49 deletions.
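Each provider change follows the same pattern: the request and response payloads are held in variables declared before the try block, and a small model_call() closure builds the ModelCall from whatever has been captured so far, so the raw request can still be attached to the output when the provider raises BadRequestError (or InvalidArgument in the Google provider). A minimal, self-contained sketch of that pattern follows; the BadRequestError class, record_call helper, and provider stub below are hypothetical stand-ins for illustration, not the real inspect_ai or provider SDK APIs.

from typing import Any

class BadRequestError(Exception):
    """Stand-in for a provider SDK's bad-request exception."""

def record_call(request: dict[str, Any], response: dict[str, Any]) -> dict[str, Any]:
    # stand-in for ModelCall.create(request=..., response=..., filter=...)
    return {"request": request, "response": response}

def generate(prompt: str) -> tuple[str, dict[str, Any]]:
    # request/response are declared before the try block so the closure
    # below sees whatever has been captured when an error is raised
    request: dict[str, Any] = {}
    response: dict[str, Any] = {}

    def model_call() -> dict[str, Any]:
        return record_call(request, response)

    try:
        request = {"messages": [{"role": "user", "content": prompt}]}
        # ... call the provider here and populate `response` from its reply ...
        response = {"content": "ok"}
        return "ok", model_call()
    except BadRequestError:
        # the rejected request is still attached to the output for logging
        return "bad request", model_call()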
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@
- Fix issue with correctly logging task_args for eval-set tasks which are interrupted.
- Move `INSPECT_DISABLE_MODEL_API` into `generate()` (as opposed to `get_model()`)
- Always treat `.eval` files as logs (don't apply file name pattern restrictions as we do with `.json`).
- Log model calls when model providers return bad request errors

## v0.3.44 (04 November 2024)

26 changes: 17 additions & 9 deletions src/inspect_ai/model/_providers/anthropic.py
@@ -125,6 +125,17 @@ async def generate(
tool_choice: ToolChoice,
config: GenerateConfig,
) -> ModelOutput | tuple[ModelOutput, ModelCall]:
# setup request and response for ModelCall
request: dict[str, Any] = {}
response: dict[str, Any] = {}

def model_call() -> ModelCall:
return ModelCall.create(
request=request,
response=response,
filter=model_call_filter,
)

# generate
try:
(
@@ -135,7 +146,7 @@
) = await resolve_chat_input(self.model_name, input, tools, config)

# prepare request params (assembled this way so we can log the raw model call)
request: dict[str, Any] = dict(messages=messages)
request = dict(messages=messages)

# system messages and tools
if system_param is not None:
@@ -156,22 +167,19 @@
# call model
message = await self.client.messages.create(**request, stream=False)

# set response for ModelCall
response = message.model_dump()

# extract output
output = model_output_from_message(message, tools)

# return output and call
call = ModelCall.create(
request=request,
response=message.model_dump(),
filter=model_call_filter,
)

return output, call
return output, model_call()

except BadRequestError as ex:
error_output = self.handle_bad_request(ex)
if error_output is not None:
return error_output
return error_output, model_call()
else:
raise ex

33 changes: 18 additions & 15 deletions src/inspect_ai/model/_providers/google.py
@@ -137,6 +137,19 @@ async def generate(
gemini_tools = chat_tools(tools) if len(tools) > 0 else None
gemini_tool_config = chat_tool_config(tool_choice) if len(tools) > 0 else None

# response for ModelCall
response: AsyncGenerateContentResponse | None = None

def model_call() -> ModelCall:
return build_model_call(
contents=contents,
safety_settings=self.safety_settings,
generation_config=parameters,
tools=gemini_tools,
tool_config=gemini_tool_config,
response=response,
)

# cast to AsyncGenerateContentResponse since we passed stream=False
try:
response = cast(
@@ -150,7 +163,7 @@
),
)
except InvalidArgument as ex:
return self.handle_invalid_argument(ex)
return self.handle_invalid_argument(ex), model_call()

# build output
output = ModelOutput(
@@ -163,18 +176,8 @@
),
)

# build call
call = model_call(
contents=contents,
safety_settings=self.safety_settings,
generation_config=parameters,
tools=gemini_tools,
tool_config=gemini_tool_config,
response=response,
)

# return
return output, call
return output, model_call()

def handle_invalid_argument(self, ex: InvalidArgument) -> ModelOutput:
if "size exceeds the limit" in ex.message.lower():
@@ -197,13 +200,13 @@ def connection_key(self) -> str:
return self.model_name


def model_call(
def build_model_call(
contents: list[ContentDict],
generation_config: GenerationConfig,
safety_settings: SafetySettingDict,
tools: list[Tool] | None,
tool_config: ToolConfig | None,
response: AsyncGenerateContentResponse,
response: AsyncGenerateContentResponse | None,
) -> ModelCall:
return ModelCall.create(
request=dict(
@@ -217,7 +220,7 @@ def model_call(
if tool_config is not None
else None,
),
response=response.to_dict(),
response=response.to_dict() if response is not None else {},
filter=model_call_filter,
)

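In the Google provider the module-level helper is renamed from model_call to build_model_call so it no longer collides with the local model_call() closure, and its response parameter becomes optional because no response object exists when the request itself is rejected. A rough sketch of that optional handling is below; the HasToDict protocol and build_call_record helper are illustrative stand-ins, not the actual AsyncGenerateContentResponse or build_model_call signatures.

from typing import Any, Protocol

class HasToDict(Protocol):
    # minimal stand-in for a response object that exposes to_dict()
    def to_dict(self) -> dict[str, Any]: ...

def build_call_record(
    request: dict[str, Any], response: HasToDict | None
) -> dict[str, Any]:
    # fall back to an empty dict when the provider never produced a response
    return {
        "request": request,
        "response": response.to_dict() if response is not None else {},
    }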
36 changes: 23 additions & 13 deletions src/inspect_ai/model/_providers/openai.py
@@ -157,6 +157,17 @@ async def generate(
**self.completion_params(config, False),
)

# setup request and response for ModelCall
request: dict[str, Any] = {}
response: dict[str, Any] = {}

def model_call() -> ModelCall:
return ModelCall.create(
request=request,
response=response,
filter=image_url_filter,
)

# unlike text models, vision models require a max_tokens (and set it to a very low
# default, see https://community.openai.com/t/gpt-4-vision-preview-finish-details/475911/10)
OPENAI_IMAGE_DEFAULT_TOKENS = 4096
@@ -176,33 +187,32 @@

try:
# generate completion
response: ChatCompletion = await self.client.chat.completions.create(
completion: ChatCompletion = await self.client.chat.completions.create(
**request
)

# save response for model_call
response = completion.model_dump()

# parse out choices
choices = self._chat_choices_from_response(response, tools)
choices = self._chat_choices_from_response(completion, tools)

# return output and call
return ModelOutput(
model=response.model,
model=completion.model,
choices=choices,
usage=(
ModelUsage(
input_tokens=response.usage.prompt_tokens,
output_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
input_tokens=completion.usage.prompt_tokens,
output_tokens=completion.usage.completion_tokens,
total_tokens=completion.usage.total_tokens,
)
if response.usage
if completion.usage
else None
),
), ModelCall.create(
request=request,
response=response.model_dump(),
filter=image_url_filter,
)
), model_call()
except BadRequestError as e:
return self.handle_bad_request(e)
return self.handle_bad_request(e), model_call()

def _chat_choices_from_response(
self, response: ChatCompletion, tools: list[ToolInfo]
30 changes: 18 additions & 12 deletions src/inspect_ai/model/_providers/openai_o1.py
@@ -60,26 +60,32 @@ async def generate_o1(
messages=chat_messages(input, tools, handler),
**params,
)
response: dict[str, Any] = {}

def model_call() -> ModelCall:
return ModelCall.create(
request=request,
response=response,
)

try:
response: ChatCompletion = await client.chat.completions.create(**request)
completion: ChatCompletion = await client.chat.completions.create(**request)
response = completion.model_dump()
except BadRequestError as ex:
return handle_bad_request(model, ex)
return handle_bad_request(model, ex), model_call()

# return model output
return ModelOutput(
model=response.model,
choices=chat_choices_from_response(response, tools, handler),
model=completion.model,
choices=chat_choices_from_response(completion, tools, handler),
usage=ModelUsage(
input_tokens=response.usage.prompt_tokens,
output_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
input_tokens=completion.usage.prompt_tokens,
output_tokens=completion.usage.completion_tokens,
total_tokens=completion.usage.total_tokens,
)
if response.usage
if completion.usage
else None,
), ModelCall.create(
request=request,
response=response.model_dump(),
)
), model_call()


def handle_bad_request(model: str, ex: BadRequestError) -> ModelOutput:
