Skip to content

Commit

Permalink
Fix loading of manually-downloaded dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
huyiwen committed Feb 27, 2024
1 parent ab2fe81 commit 7c27af8
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 8 deletions.
8 changes: 1 addition & 7 deletions llmbox/dataset/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

logger = getLogger(__name__)

_dataset_class = {}


def import_dataset_class(dataset_name: str) -> Dataset:
if "wmt" in dataset_name:
Expand All @@ -31,17 +29,13 @@ def import_dataset_class(dataset_name: str) -> Dataset:

return Squad

if dataset_name in _dataset_class:
return _dataset_class[dataset_name]

module_path = __package__ + "." + dataset_name
module = importlib.import_module(module_path)
clsmembers = inspect.getmembers(module, inspect.isclass)

for name, obj in clsmembers:
if issubclass(obj, Dataset) and name.lower() == dataset_name.lower():
logger.debug(f"Dataset class `{name}` imported from `{module_path}`.")
_dataset_class[dataset_name] = obj
return obj

raise ValueError(
Expand Down Expand Up @@ -72,7 +66,6 @@ def load_dataset(args: "DatasetArguments", model: "Model", threading: bool = Tru
# TODO catch connection warning
if available_subsets == {"default"}:
available_subsets = set()
logger.debug(f"{name} - available_subsets: {available_subsets}, load_args: {dataset_cls.load_args}")

# for wmt, en-xx and xx-en are both supported
if "wmt" in args.dataset_name:
Expand All @@ -91,6 +84,7 @@ def load_dataset(args: "DatasetArguments", model: "Model", threading: bool = Tru

# use specified subset_names if available
subset_names = args.subset_names or available_subsets
logger.debug(f"{name} - available_subsets: {available_subsets}, load_args: {dataset_cls.load_args}, final subset_names: {subset_names}")

# GPTEval requires openai-gpt
if any(isinstance(m, GPTEval) for m in dataset_cls.metrics) and model.args.openai_api_key is None:
Expand Down
2 changes: 2 additions & 0 deletions llmbox/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def get_raw_dataset_loader(
if dataset_path is not None:
dataset_path = abspath(dataset_path)
msg += f" from local path `{dataset_path}`"
if subset_name is None and len(load_args) > 1 and load_args[1] is not None:
subset_name = load_args[1]

# load from a cloned repository from huggingface
if os.path.exists(os.path.join(dataset_path, "dataset_infos.json")):
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,3 @@ langcodes
language_data
anthropic
google-api-python-client
prefetch_generator

0 comments on commit 7c27af8

Please sign in to comment.