forked from SylphAI-Inc/AdalFlow
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request SylphAI-Inc#371 from SylphAI-Inc/fix2
add gsm8k dataset
- Loading branch information
Showing
11 changed files
with
370 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
import random | ||
import os | ||
from typing import Literal | ||
import tqdm | ||
|
||
from adalflow.utils.lazy_import import safe_import, OptionalPackages | ||
|
||
|
||
from adalflow.utils.data import Dataset | ||
from adalflow.utils.file_io import save_json, load_json | ||
from adalflow.datasets.utils import prepare_dataset_path | ||
from adalflow.core.base_data_class import DataClass | ||
from adalflow.datasets.types import GSM8KData | ||
from adalflow.utils import printc | ||
|
||
|
||
class GSM8K(Dataset): | ||
__doc__ = r""" Use huggingface datasets to load GSM8K dataset. | ||
official_train: 7473 | ||
official_test: 1319 | ||
Our train split: 3736/2 | ||
Our val split: 3736/2 | ||
Our test split: 1319 | ||
You can use size to limit the number of examples to load. | ||
Example: | ||
.. code-block:: python | ||
dataset = GSM8K(split="train", size=10) | ||
print(f"example: {dataset[0]}") | ||
The output will be: | ||
.. code-block:: | ||
GSM8KData(id='8fc791e6-ea1d-472c-a882-d00d0600d423', | ||
question="The result from the 40-item Statistics exam Marion and Ella took already came out. | ||
Ella got 4 incorrect answers while Marion got 6 more than half the score of Ella. | ||
What is Marion's score?", | ||
answer='24', | ||
gold_reasoning="Ella's score is 40 items - 4 items = <<40-4=36>>36 items. | ||
Half of Ella's score is 36 items / 2 = <<36/2=18>>18 items. | ||
So, Marion's score is 18 items + 6 items = <<18+6=24>>24 items.", | ||
reasoning=None) | ||
""" | ||
|
||
def __init__( | ||
self, | ||
root: str = None, | ||
split: Literal["train", "val", "test"] = "train", | ||
size: int = None, | ||
**kwargs, | ||
) -> None: | ||
|
||
if split not in ["train", "val", "test"]: | ||
raise ValueError("Split must be one of 'train', 'val', 'test'") | ||
|
||
self.root = root | ||
self.task_name = "gsm8k" | ||
data_path = prepare_dataset_path(self.root, self.task_name) | ||
# download and save | ||
split_csv_path = os.path.join(data_path, f"{split}.json") | ||
print(f"split_csv_path: {split_csv_path}") | ||
self._check_or_download_dataset(split_csv_path, split) | ||
|
||
# load from csv | ||
self.data = [] | ||
|
||
self.data = load_json(split_csv_path) | ||
if size is not None: | ||
self.data = self.data[:size] | ||
# convert to dataclass | ||
self.data = [GSM8KData.from_dict(d) for d in self.data] | ||
|
||
def _check_or_download_dataset( | ||
self, | ||
data_path: str = None, | ||
split: str = "train", | ||
): | ||
r"""It will download data from huggingface datasets and split it and save it into three csv files. | ||
Args: | ||
data_path (str): The path to save the data. In particular with split name appended. | ||
split (str): The dataset split, supports ``"train"`` (default), ``"val"`` and ``"test"``. Decides which split to return. | ||
only_hard_examples (bool): If True, only hard examples will be downloaded. | ||
keep_details (str): If "all", all details will be kept. If "dev_titles", only dev titles will be kept. | ||
""" | ||
|
||
if data_path is None: | ||
raise ValueError("data_path must be specified") | ||
|
||
if os.path.exists(data_path): | ||
return | ||
|
||
safe_import( | ||
OptionalPackages.DATASETS.value[0], OptionalPackages.DATASETS.value[1] | ||
) | ||
from datasets import load_dataset | ||
|
||
# use huggingface cache | ||
gsm8k_dataset = load_dataset("gsm8k", "main", cache_dir=self.root) | ||
|
||
hf_official_train = gsm8k_dataset["train"] | ||
hf_official_test = gsm8k_dataset["test"] | ||
|
||
official_train = [] | ||
official_test = [] | ||
|
||
for example in tqdm.tqdm(hf_official_train): | ||
question = example["question"] | ||
answer = example["answer"].strip().split() | ||
assert answer[-2] == "####" | ||
|
||
gold_reasoning = " ".join(answer[:-2]) | ||
answer = str(int(answer[-1].replace(",", ""))) | ||
official_train.append( | ||
dict(question=question, gold_reasoning=gold_reasoning, answer=answer) | ||
) | ||
|
||
for example in tqdm.tqdm(hf_official_test): | ||
question = example["question"] | ||
answer = example["answer"].strip().split() | ||
assert answer[-2] == "####" | ||
|
||
gold_reasoning = " ".join(answer[:-2]) | ||
answer = str(int(answer[-1].replace(",", ""))) | ||
official_test.append( | ||
dict(question=question, gold_reasoning=gold_reasoning, answer=answer) | ||
) | ||
|
||
rng = random.Random(0) | ||
rng.shuffle(official_train) # 7473 train | ||
rng = random.Random(0) | ||
rng.shuffle(official_test) # 1319 test | ||
|
||
printc(f"official_train: {len(official_train)}") | ||
printc(f"official_test: {len(official_test)}") | ||
train_set = official_train[: len(official_train) * 50 // 100] | ||
val_set = official_train[len(official_train) * 50 // 100 :] | ||
data_path_dir = os.path.dirname(data_path) | ||
for split, examples in zip( | ||
["train", "val", "test"], | ||
[train_set, val_set, official_test], | ||
): | ||
target_path = os.path.join(data_path_dir, f"{split}.json") | ||
save_json(examples, f=target_path) | ||
|
||
if split == "train": | ||
return train_set | ||
elif split == "val": | ||
return val_set | ||
else: | ||
return official_test | ||
|
||
def __getitem__(self, index) -> DataClass: | ||
return self.data[index] | ||
|
||
def __len__(self): | ||
return len(self.data) | ||
|
||
|
||
if __name__ == "__main__": | ||
dataset = GSM8K(split="train", size=10) | ||
|
||
print(f"len: {len(dataset)}") | ||
print(f"dataset[0]: {dataset[0]}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from typing import Dict, Union | ||
import re | ||
import adalflow as adal | ||
|
||
template = r"""<START_OF_SYSTEM_PROMPT> | ||
{{system_prompt}} | ||
<END_OF_SYSTEM_PROMPT> | ||
<START_OF_USER_PROMPT> | ||
{{input_str}} | ||
<END_OF_USER_PROMPT> | ||
""" | ||
|
||
system_prompt_start = "You will answer a reasoning question. Think step by step. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value." | ||
|
||
|
||
@adal.func_to_data_component | ||
def parse_integer_answer(answer: str) -> str: | ||
try: | ||
numbers = re.findall(r"\d+", answer) | ||
if numbers: | ||
answer = numbers[-1] | ||
else: | ||
answer = "" | ||
except ValueError: | ||
answer = "" | ||
|
||
return answer | ||
|
||
|
||
class GSM8KTask(adal.Component): | ||
def __init__(self, model_client: adal.ModelClient, model_kwargs: Dict): | ||
super().__init__() | ||
|
||
system_prompt = adal.Parameter( | ||
data=system_prompt_start, | ||
role_desc="To give task instruction to the language model in the system prompt", | ||
requires_opt=True, | ||
param_type=adal.ParameterType.PROMPT, | ||
) | ||
self.generator = adal.Generator( | ||
model_client=model_client, | ||
model_kwargs=model_kwargs, | ||
prompt_kwargs={ | ||
"system_prompt": system_prompt, | ||
}, | ||
template=template, | ||
output_processors=parse_integer_answer, | ||
use_cache=True, | ||
) | ||
|
||
def bicall( | ||
self, question: str, id: str = None | ||
) -> Union[adal.GeneratorOutput, adal.Parameter]: | ||
output = self.generator(prompt_kwargs={"input_str": question}, id=id) | ||
return output | ||
|
||
|
||
if __name__ == "__main__": | ||
from adalflow.utils import setup_env | ||
from adalflow.datasets.gsm8k import GSM8K | ||
|
||
setup_env() | ||
|
||
from use_cases.config import gpt_3_model | ||
|
||
task = GSM8KTask(**gpt_3_model) | ||
|
||
train_dataset = GSM8K(split="train", size=10) | ||
|
||
print("example: ", train_dataset[0]) | ||
|
||
output = task(question=train_dataset[0].question) | ||
print("output: ", output) |
Oops, something went wrong.