Update datasets - mind2web and new synthetic ds
andrew-healey committed Oct 3, 2023
1 parent 08c0657 commit 3c0142e
Showing 2 changed files with 78 additions and 5 deletions.
14 changes: 9 additions & 5 deletions src/llama2d/datasets/mind2web.py
@@ -22,7 +22,7 @@
 
 class Mind2webDataset(Dataset):
     def __init__(
-        self, model="decapoda-research/llama-7b-hf", playwright=None, headless=False
+        self, model="decapoda-research/llama-7b-hf", playwright=None, headless=False, show_errors=False
     ):
         assert playwright is not None, "Please pass in playwright"
         self.__extractor = Llama2dWebsiteFeatureExtractor(mask_out_body=True)
@@ -53,6 +53,8 @@ def __init__(
                 "Safari/537.36"
             }
         )
+        self.page.set_default_navigation_timeout(1000 * 10)
+        self.show_errors = show_errors
 
     def __len__(self):
         return len(self.actions)
@@ -118,13 +120,15 @@ def __getitem__(self, index)
             return ret
         except Exception as e:
             # raise e
-            print("Error in dataset:", e)
+            if self.show_errors:
+                print("Error in dataset:", str(e)[:100] + "...")
 
             if "ImageAnnotation" in str(e):
                 raise e
 
             if screenshot_path is not None:
-                os.remove(screenshot_path)
+                if os.path.exists(screenshot_path):
+                    os.remove(screenshot_path)
             return None
 
     def get_uid_to_mhtml_map(self) -> Dict[str, str]:
@@ -148,9 +152,9 @@ def get_uid(path)
         )
 
 with sync_playwright() as playwright:
-    dataset = Mind2webDataset(playwright=playwright,headless=True)
+    dataset = Mind2webDataset(playwright=playwright, headless=True, show_errors=False)
 
-    debug_dataset(dataset)
+    # debug_dataset(dataset)
 
     # publish a subset
     num_samples = 2_000
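
A usage note on these changes: show_errors defaults to False, so per-sample failures stay silent, and set_default_navigation_timeout(1000 * 10) caps each page load at 10 seconds (Playwright timeouts are given in milliseconds). Below is a minimal sketch of driving the dataset with error logging enabled; the 10-sample loop is illustrative, not part of the commit:

    from playwright.sync_api import sync_playwright
    from llama2d.datasets.mind2web import Mind2webDataset

    with sync_playwright() as playwright:
        # show_errors=True prints the first 100 characters of each
        # per-sample exception instead of swallowing it.
        dataset = Mind2webDataset(playwright=playwright, headless=True, show_errors=True)

        # A sample whose navigation exceeds the 10 s default timeout (or fails
        # for any non-ImageAnnotation reason) comes back as None, so filter.
        samples = [s for i in range(10) if (s := dataset[i]) is not None]
        print(f"{len(samples)}/10 samples extracted")
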
69 changes: 69 additions & 0 deletions src/llama2d/datasets/synthetic/unscramble_words.py
@@ -0,0 +1,69 @@
+
+from llama2d.vision import debug_dataset, Llama2dTokenizer, Llama2dScreen
+from llama2d.datasets.huggingface import DatasetInfo, publish_pt_dataset
+from torch.utils.data import Dataset
+
+from random import choice, shuffle
+rand_words = "bob,jane,alice,carol,ted,lisa,barry,frank,george,harold,henry,ian,john,james,kevin,mark,neil,oliver,peter,quinn,robert,steve,thomas,william".split(",")
+
+class UnscrambleDataset(Dataset):
+    def __init__(
+        self,
+        num_screens: int,
+        words_per_screen: int,
+        words_per_line: int = 20,
+        lines_per_screen: int = 5,
+        tokenizer: Llama2dTokenizer = None
+    ):
+        self.num_screens = num_screens
+        self.words_per_screen = words_per_screen
+
+        if tokenizer is None:
+            tokenizer = Llama2dTokenizer()
+        self.tokenizer = tokenizer
+
+        self.screens = []
+        for _ in range(num_screens):
+            screen = Llama2dScreen()
+
+            words = [choice(rand_words) for _ in range(words_per_screen)]
+
+            # render in a grid of lines
+            for k, word in enumerate(words):
+                i, j = k % words_per_line, k // words_per_line
+                # convert i,j to x,y, where x is horizontal and y is vertical
+                # x is in [0,1] and y is in [0,1]
+
+                x = (i + 0.5) / words_per_line
+                y = (j + 0.5) / lines_per_screen
+
+                assert y < 1, "Too many words for the screen"
+
+                screen.push_word(word=word, xy=(x, y))
+
+            # shuffle the word order; positions stay fixed, so reading order must come from coordinates
+            shuffle(screen.words)
+
+            prompt = "Read out the words in the order they appear."
+            response = " ".join(words)
+
+            self.screens.append(self.tokenizer.process(prompt, screen, response))
+
+    def __len__(self):
+        return self.num_screens
+    def __getitem__(self, i: int):
+        return self.screens[i]
+
+if __name__ == "__main__":
+
+    dataset = UnscrambleDataset(
+        num_screens=500,
+        words_per_screen=50,
+        words_per_line=15,
+        lines_per_screen=5
+    )
+
+    debug_dataset(dataset)
+
+    info = DatasetInfo(repo="llama2d/llama2d-unscramble", desc="Unscramble the words displayed on the screen.")
+    publish_pt_dataset(dataset, info)
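
One step worth unpacking is the index-to-coordinate mapping: word k lands in column i = k % words_per_line and row j = k // words_per_line, and the +0.5 offsets place each word at the center of its grid cell before normalizing into [0,1]. A standalone sketch of the same arithmetic (plain Python; grid_xy is a hypothetical helper, not in the commit):

    def grid_xy(k: int, words_per_line: int, lines_per_screen: int) -> tuple:
        """Center of grid cell k, in normalized [0,1] x [0,1] coordinates."""
        i, j = k % words_per_line, k // words_per_line
        x = (i + 0.5) / words_per_line
        y = (j + 0.5) / lines_per_screen
        assert y < 1, "Too many words for the screen"
        return x, y

    # With words_per_line=15 and lines_per_screen=5 (the __main__ config),
    # 50 words occupy rows 0-3 (values rounded):
    # grid_xy(0, 15, 5)  -> (0.033, 0.1)   first word, top-left cell
    # grid_xy(14, 15, 5) -> (0.967, 0.1)   last word of the first row
    # grid_xy(15, 15, 5) -> (0.033, 0.3)   wraps to the second row
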
