From 1c3e0fbde823af32e197307a28c43f40c24c7dc1 Mon Sep 17 00:00:00 2001 From: Daniel Paul Gonzalez <427300+daniel-p-gonzalez@users.noreply.github.com> Date: Tue, 14 Mar 2023 16:00:48 -0400 Subject: [PATCH] Fix dataset loading and PyTorch 1.x compatibility Fixes dataset loading on Windows by specifying the encoding to be utf-8 on file open. Fixes the usage of torch.compile for PyTorch 1.x by using a pass-through decorator if 2.x isn't detected. --- main.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index d9074f8..c03fc28 100644 --- a/main.py +++ b/main.py @@ -32,6 +32,9 @@ using_pytorch_2 = (int(torch.__version__.split('.')[0]) >= 2) if not using_pytorch_2: print("Info: Pytorch 2.0 isn't currently installed. Falling back to slower Pytorch 1.x pathway.") + torch_compile = lambda func: func +else: + torch_compile = torch.compile ## <-- teaching comments # <-- functional comments @@ -102,10 +105,10 @@ with zipfile.ZipFile('data_raw/data.zip', 'r') as zip_ref: zip_ref.extractall('data_raw/') - with open('data_raw/wikitext-103-raw/wiki.train.raw', 'r') as data_file: + with open('data_raw/wikitext-103-raw/wiki.train.raw', 'r', encoding="utf8") as data_file: raw_train_data = data_file.read() - with open('data_raw/wikitext-103-raw/wiki.valid.raw', 'r') as data_file: + with open('data_raw/wikitext-103-raw/wiki.valid.raw', 'r', encoding="utf8") as data_file: raw_eval_data = data_file.read() tokenizer = tiktoken.get_encoding("gpt2") @@ -328,7 +331,7 @@ def init_split_parameter_dictionaries(net): return params_non_decay, params_decay -@torch.compile +@torch_compile def get_grad_norm(net): # Gets the entire grad norm of the network.
grad_norm = torch.tensor(0., device=hyp['misc']['device']) @@ -529,4 +532,4 @@ def main(): for i in range(5): _, val_loss = main() val_loss_list.append(val_loss) - print(f"Average final val loss: {sum(val_loss_list)/len(val_loss_list)}") # TODO add variance as well, later \ No newline at end of file + print(f"Average final val loss: {sum(val_loss_list)/len(val_loss_list)}") # TODO add variance as well, later