Skip to content

Commit

Permalink
ensure &= handled to avoid dict vs. bool comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudotensor committed Aug 8, 2024
1 parent 7435b4b commit 9999b08
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 8 deletions.
15 changes: 8 additions & 7 deletions src/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
have_langchain, set_openai, cuda_vis_check, H2O_Fire, lg_to_gr, str_to_list, str_to_dict, get_token_count, \
have_wavio, have_soundfile, have_deepspeed, have_doctr, have_librosa, have_TTS, have_flash_attention_2, \
have_diffusers, sanitize_filename, get_gradio_tmp, get_is_gradio_h2oai, get_json, \
get_docs_tokens, deduplicate_names, have_autogen, get_model_name
get_docs_tokens, deduplicate_names, have_autogen, get_model_name, is_empty

start_faulthandler()
import_matplotlib()
Expand Down Expand Up @@ -2959,14 +2959,15 @@ def evaluate(
schema_instruction = json_schema_instruction.format(properties_schema=guided_json_properties_json)

pre_instruction = ''
supports_schema = guided_json and response_format == 'json_object' and is_json_model(base_model,
inference_server,
json_vllm=json_vllm)
supports_schema = not is_empty(guided_json) and \
response_format == 'json_object' and \
is_json_model(base_model, inference_server, json_vllm=json_vllm)
supports_schema &= json_vllm or \
inference_server and \
not is_empty(inference_server) and \
any(inference_server.startswith(x) for x in ['openai_chat', 'openai_azure_chat']) and \
base_model in openai_supports_functiontools + openai_supports_parallel_functiontools or \
inference_server and \
not is_empty(
base_model) and base_model in openai_supports_functiontools + openai_supports_parallel_functiontools or \
not is_empty(inference_server) and \
inference_server.startswith('anthropic')

if supports_schema:
Expand Down
36 changes: 36 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import traceback
import zipfile
import tarfile
from array import array
from collections import deque
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime
from typing import Tuple, Callable, Dict
Expand Down Expand Up @@ -2985,3 +2987,37 @@ def has_single_element_sublist(lst, depth):
if has_single_element_sublist(lst, depth):
depth -= 1
return depth


def is_empty(obj):
if obj is None:
return True
if isinstance(obj, (str, list, tuple, dict, set)):
return len(obj) == 0
if isinstance(obj, bool):
return False
if isinstance(obj, (int, float)):
# Numbers can't be "empty" in the traditional sense, so go by value for them
return False if 0 else True
if isinstance(obj, complex):
return obj == 0
if isinstance(obj, bytes):
return len(obj) == 0
if isinstance(obj, bytearray):
return len(obj) == 0
if isinstance(obj, memoryview):
return len(obj) == 0
if isinstance(obj, range):
return len(obj) == 0
if isinstance(obj, frozenset):
return len(obj) == 0
if isinstance(obj, deque):
return len(obj) == 0
if isinstance(obj, array):
return len(obj) == 0
if isinstance(obj, (map, filter, zip)):
# These are iterators and need to be converted to a list to check if they are empty
return len(list(obj)) == 0
if hasattr(obj, '__len__'):
return len(obj) == 0
return False
2 changes: 1 addition & 1 deletion src/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "bd7b35638f2e395e518d5f59ac92d214bbed15fd"
__version__ = "7435b4bc4c0e6559fd90e89f7a3f51f9353ccf89"

0 comments on commit 9999b08

Please sign in to comment.