From c2305e1a950d7b3ca335b9af4654805723bffa27 Mon Sep 17 00:00:00 2001
From: Dmytro Chaplynskyi
Date: Sun, 22 Oct 2023 21:25:47 +0000
Subject: [PATCH] Switching run_mlm to xla, wrapping things

---
 .gitignore | 1 +
 run_mlm.py | 4 +++-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 620b531..2b91396 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,4 @@ eval/preprocess/
 data/
 bruk_corpus/
 .DS_Store
+exps/
diff --git a/run_mlm.py b/run_mlm.py
index e7b3040..731d3e8 100644
--- a/run_mlm.py
+++ b/run_mlm.py
@@ -48,6 +48,7 @@ from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 
 import wandb
+import torch_xla.core.xla_model as xm
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version("4.18.0.dev0")
@@ -291,7 +292,7 @@ def main():
 
     for (name, path) in [("train", data_args.train_file), ("validation", data_args.validation_file)]:
         if path.endswith(".txt"):
-            raw_datasets[name] = datasets.Dataset.from_dict({"text": [open(path).read()], "id": [0]})
+            raw_datasets[name] = datasets.Dataset.from_dict({"text": [open(path).read()], "compound_id": [0], "id": [0]})
         else:
             raw_datasets[name] = datasets.load_from_disk(data_args.train_file)
 
@@ -351,6 +352,7 @@ def main():
 
     model = AutoModelForMaskedLM.from_config(config)
     model.resize_token_embeddings(len(tokenizer))
+    model.to(xm.xla_device())
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.
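
Note: the one-line model.to(xm.xla_device()) change above is the standard torch_xla idiom for
placing a model on the attached TPU/XLA device before training begins. A minimal sketch of that
same pattern in isolation, assuming torch and torch_xla are installed and an XLA device is
available (the tiny linear model and random batch are placeholders, not code from run_mlm.py):

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()                    # resolves the attached TPU/XLA device
    model = torch.nn.Linear(8, 2).to(device)    # placeholder model; the patch moves the MLM model the same way
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    x = torch.randn(4, 8, device=device)        # placeholder batch on the XLA device
    y = torch.randint(0, 2, (4,), device=device)

    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()
    xm.optimizer_step(optimizer, barrier=True)  # optimizer.step() plus a step barrier, forcing the lazy XLA graph to execute

XLA tensors are lazy, so nothing actually runs on the device until a step barrier (barrier=True
here, or an explicit xm.mark_step() inside a longer loop); in run_mlm.py itself, the surrounding
Trainer loop takes over once the model has been placed on the device.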