Skip to content

Commit

Permalink
Fix dataset loading and PyTorch 1.x compatibility
Browse files Browse the repository at this point in the history
Fixes dataset loading on Windows by explicitly specifying UTF-8 encoding when opening the data files.
Fixes the usage of torch.compile under PyTorch 1.x by substituting a pass-through decorator when a 2.x install isn't detected.
  • Loading branch information
daniel-p-gonzalez authored Mar 14, 2023
1 parent 789c0ad commit 1c3e0fb
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
# Compatibility shim: torch.compile only exists from PyTorch 2.x onward.
# On 1.x we substitute an identity decorator so that functions decorated
# with @torch_compile still run (uncompiled, hence slower).
using_pytorch_2 = (int(torch.__version__.split('.')[0]) >= 2)
if not using_pytorch_2:
    print("Info: Pytorch 2.0 isn't currently installed. Falling back to slower Pytorch 1.x pathway.")

    # PEP 8 (E731): bind a def, not a lambda, to a name — keeps a useful
    # __name__ in tracebacks and profiler output.
    def torch_compile(func):
        """Identity decorator used when torch.compile is unavailable."""
        return func
else:
    torch_compile = torch.compile

## <-- teaching comments
# <-- functional comments
Expand Down Expand Up @@ -102,10 +105,10 @@
# Unpack the raw wikitext-103 archive and read both splits into memory.
# The encoding is given explicitly because on Windows the default locale
# codec (e.g. cp1252) cannot decode this corpus; "utf8" is the alias the
# upstream fix used (equivalent to "utf-8").
with zipfile.ZipFile('data_raw/data.zip', 'r') as zip_ref:
    zip_ref.extractall('data_raw/')

with open('data_raw/wikitext-103-raw/wiki.train.raw', 'r', encoding="utf8") as data_file:
    raw_train_data = data_file.read()

with open('data_raw/wikitext-103-raw/wiki.valid.raw', 'r', encoding="utf8") as data_file:
    raw_eval_data = data_file.read()

# GPT-2 BPE tokenizer; tiktoken is imported at the top of the file.
tokenizer = tiktoken.get_encoding("gpt2")
Expand Down Expand Up @@ -328,7 +331,7 @@ def init_split_parameter_dictionaries(net):

return params_non_decay, params_decay

@torch.compile
@torch_compile
def get_grad_norm(net):
# Gets the entire grad norm of the network.
grad_norm = torch.tensor(0., device=hyp['misc']['device'])
Expand Down Expand Up @@ -529,4 +532,4 @@ def main():
# Benchmark driver: run five independent training runs and report the
# mean final validation loss. NOTE(review): val_loss_list is presumably
# initialized to [] just above this hunk — confirm against the full file.
for _ in range(5):
    _, val_loss = main()
    val_loss_list.append(val_loss)
print(f"Average final val loss: {sum(val_loss_list)/len(val_loss_list)}") # TODO add variance as well, later

0 comments on commit 1c3e0fb

Please sign in to comment.